diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7fb01708a3d..09e92a66759 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,13 +50,13 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ - exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ + exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ - id: copyright_checker name: copyright_checker entry: python .pre-commit-hooks/copyright-check.hook language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ - exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ + exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e8315e76c7..62fead47015 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,13 @@ # Changelog +Date: 2022-3-08, Author: yt605155624. +Add features to: T2S: + - Add aishell3 hifigan egs. + - PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1545 + +Date: 2022-3-08, Author: yt605155624. +Add features to: T2S: + - Add vctk hifigan egs. + - PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1544 Date: 2022-1-29, Author: yt605155624. Add features to: T2S: diff --git a/README.md b/README.md index 46f492e9980..ceef15af62c 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision -- 🤗 2021.12.14: Our PaddleSpeech [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/akhaliq/paddlespeech) Demos on Hugging Face Spaces are available! +- 🤗 2021.12.14: Our PaddleSpeech [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) Demos on Hugging Face Spaces are available! - 👏🏻 2021.12.10: PaddleSpeech CLI is available for Audio Classification, Automatic Speech Recognition, Speech Translation (English to Chinese) and Text-to-Speech. ### Community @@ -207,6 +207,7 @@ paddlespeech cls --input input.wav ```shell paddlespeech asr --lang zh --input input_16k.wav ``` +- web demo for Automatic Speech Recognition is integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See Demo: [ASR Demo](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) **Speech Translation** (English to Chinese) (not support for Mac and Windows now) @@ -218,7 +219,7 @@ paddlespeech st --input input_16k.wav ```shell paddlespeech tts --input "你好,欢迎使用飞桨深度学习框架!" --output output.wav ``` -- web demo for Text to Speech is integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See Demo: [TTS Demo](https://huggingface.co/spaces/akhaliq/paddlespeech) +- web demo for Text to Speech is integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See Demo: [TTS Demo](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) **Text Postprocessing** - Punctuation Restoration @@ -397,9 +398,9 @@ PaddleSpeech supports a series of most popular models. 
They are summarized in [r HiFiGAN - CSMSC + LJSpeech / VCTK / CSMSC / AISHELL-3 - HiFiGAN-csmsc + HiFiGAN-ljspeech / HiFiGAN-vctk / HiFiGAN-csmsc / HiFiGAN-aishell3 @@ -573,7 +574,6 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P - Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help. -- Many thanks to [AK391](https://github.com/AK391) for TTS web demo on Huggingface Spaces using Gradio. - Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files. - Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function. - Many thanks to [745165806](https://github.com/745165806)/[PaddleSpeechTask](https://github.com/745165806/PaddleSpeechTask) for contributing Punctuation Restoration model. diff --git a/README_cn.md b/README_cn.md index e8494737299..8ea91e98d42 100644 --- a/README_cn.md +++ b/README_cn.md @@ -392,9 +392,9 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 HiFiGAN - CSMSC + LJSpeech / VCTK / CSMSC / AISHELL-3 - HiFiGAN-csmsc + HiFiGAN-ljspeech / HiFiGAN-vctk / HiFiGAN-csmsc / HiFiGAN-aishell3 diff --git a/demos/speech_recognition/README.md b/demos/speech_recognition/README.md index 5d964fceac7..636548801b4 100644 --- a/demos/speech_recognition/README.md +++ b/demos/speech_recognition/README.md @@ -84,5 +84,8 @@ Here is a list of pretrained models released by PaddleSpeech that can be used by | Model | Language | Sample Rate | :--- | :---: | :---: | -| conformer_wenetspeech| zh| 16000 -| transformer_librispeech| en| 16000 +| conformer_wenetspeech| zh| 16k +| transformer_librispeech| en| 16k +| deepspeech2offline_aishell| zh| 16k +| deepspeech2online_aishell | zh | 16k +|deepspeech2offline_librispeech|en| 16k diff --git a/demos/speech_recognition/README_cn.md b/demos/speech_recognition/README_cn.md index ba1f1d65c5c..8033dbd8130 100644 --- a/demos/speech_recognition/README_cn.md +++ b/demos/speech_recognition/README_cn.md @@ -81,5 +81,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee | 模型 | 语言 | 采样率 | :--- | :---: | :---: | -| conformer_wenetspeech| zh| 16000 -| transformer_librispeech| en| 16000 +| conformer_wenetspeech | zh | 16k +| transformer_librispeech | en | 16k +| deepspeech2offline_aishell| zh| 16k +| deepspeech2online_aishell | zh | 16k +| deepspeech2offline_librispeech | en | 16k diff --git a/demos/speech_server/.gitignore b/demos/speech_server/.gitignore new file mode 100644 index 00000000000..d8dd7532abc --- /dev/null +++ b/demos/speech_server/.gitignore @@ -0,0 +1 @@ +*.wav diff --git a/demos/speech_server/README.md b/demos/speech_server/README.md index a2f6f221320..10489e71314 100644 --- a/demos/speech_server/README.md +++ b/demos/speech_server/README.md @@ -110,21 +110,22 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav 
https://paddlespee - Python API ```python from paddlespeech.server.bin.paddlespeech_client import ASRClientExecutor + import json asrclient_executor = ASRClientExecutor() - asrclient_executor( + res = asrclient_executor( input="./zh.wav", server_ip="127.0.0.1", port=8090, sample_rate=16000, lang="zh_cn", audio_format="wav") + print(res.json()) ``` Output: ```bash {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'transcription': '我认为跑步最重要的就是给我带来了身体健康'}} - time cost 0.604353 s. ``` ### 5. TTS Client Usage @@ -146,7 +147,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee - `speed`: Audio speed, the value should be set between 0 and 3. Default: 1.0 - `volume`: Audio volume, the value should be set between 0 and 3. Default: 1.0 - `sample_rate`: Sampling rate, choice: [0, 8000, 16000], the default is the same as the model. Default: 0 - - `output`: Output wave filepath. Default: `output.wav`. + - `output`: Output wave filepath. Default: None, which means not to save the audio to the local. Output: ```bash @@ -160,9 +161,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee - Python API ```python from paddlespeech.server.bin.paddlespeech_client import TTSClientExecutor + import json ttsclient_executor = TTSClientExecutor() - ttsclient_executor( + res = ttsclient_executor( input="您好,欢迎使用百度飞桨语音合成服务。", server_ip="127.0.0.1", port=8090, @@ -171,6 +173,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee volume=1.0, sample_rate=0, output="./output.wav") + + response_dict = res.json() + print(response_dict["message"]) + print("Save synthesized audio successfully on %s." % (response_dict['result']['save_path'])) + print("Audio duration: %f s." %(response_dict['result']['duration'])) ``` Output: @@ -178,7 +185,52 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee {'description': 'success.'} Save synthesized audio successfully on ./output.wav. Audio duration: 3.612500 s. - Response time: 0.388317 s. + + ``` + +### 6. CLS Client Usage +**Note:** The response time will be slightly longer when using the client for the first time +- Command Line (Recommended) + ``` + paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav + ``` + + Usage: + + ```bash + paddlespeech_client cls --help + ``` + Arguments: + - `server_ip`: server ip. Default: 127.0.0.1 + - `port`: server port. Default: 8090 + - `input`(required): Audio file to be classified. + - `topk`: topk scores of classification result. + + Output: + ```bash + [2022-03-09 20:44:39,974] [ INFO] - {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}} + [2022-03-09 20:44:39,975] [ INFO] - Response time 0.104360 s. 
+  ```
+
+- Python API
+  ```python
+  from paddlespeech.server.bin.paddlespeech_client import CLSClientExecutor
+  import json
+
+  clsclient_executor = CLSClientExecutor()
+  res = clsclient_executor(
+      input="./zh.wav",
+      server_ip="127.0.0.1",
+      port=8090,
+      topk=1)
+  print(res.json())
+  ```
+
+  Output:
+  ```bash
+  {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}}
+  ```
@@ -189,3 +241,6 @@ Get all models supported by the ASR service via `paddlespeech_server stats --task asr`, where static models can be used for paddle inference.
 
 ### TTS model
 Get all models supported by the TTS service via `paddlespeech_server stats --task tts`, where static models can be used for paddle inference.
+
+### CLS model
+Get all models supported by the CLS service via `paddlespeech_server stats --task cls`, where static models can be used for paddle inference.
diff --git a/demos/speech_server/README_cn.md b/demos/speech_server/README_cn.md
index 762248a117f..2bd8af6c91f 100644
--- a/demos/speech_server/README_cn.md
+++ b/demos/speech_server/README_cn.md
@@ -80,7 +80,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
   ```
 
-### 4. ASR客户端使用方法
+### 4. ASR 客户端使用方法
 **注意:** 初次使用客户端时响应时间会略长
 - 命令行 (推荐使用)
   ```
   paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
   ```
@@ -111,25 +112,26 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
 - Python API
   ```python
   from paddlespeech.server.bin.paddlespeech_client import ASRClientExecutor
+  import json
 
   asrclient_executor = ASRClientExecutor()
-  asrclient_executor(
+  res = asrclient_executor(
       input="./zh.wav",
       server_ip="127.0.0.1",
       port=8090,
       sample_rate=16000,
       lang="zh_cn",
       audio_format="wav")
+  print(res.json())
   ```
 
   输出:
   ```bash
   {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'transcription': '我认为跑步最重要的就是给我带来了身体健康'}}
-  time cost 0.604353 s.
   ```
 
-### 5. TTS客户端使用方法
+### 5. TTS 客户端使用方法
 **注意:** 初次使用客户端时响应时间会略长
 - 命令行 (推荐使用)
@@ -150,7 +151,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
   - `speed`: 音频速度,该值应设置在 0 到 3 之间。 默认值:1.0
   - `volume`: 音频音量,该值应设置在 0 到 3 之间。 默认值: 1.0
   - `sample_rate`: 采样率,可选 [0, 8000, 16000],默认与模型相同。 默认值:0
-  - `output`: 输出音频的路径, 默认值:output.wav。
+  - `output`: 输出音频的路径, 默认值:None,表示不保存音频到本地。
 
   输出:
   ```bash
@@ -163,9 +164,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
 - Python API
   ```python
   from paddlespeech.server.bin.paddlespeech_client import TTSClientExecutor
+  import json
 
   ttsclient_executor = TTSClientExecutor()
-  ttsclient_executor(
+  res = ttsclient_executor(
       input="您好,欢迎使用百度飞桨语音合成服务。",
       server_ip="127.0.0.1",
       port=8090,
@@ -174,6 +176,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
       volume=1.0,
       sample_rate=0,
       output="./output.wav")
+
+  response_dict = res.json()
+  print(response_dict["message"])
+  print("Save synthesized audio successfully on %s." % (response_dict['result']['save_path']))
+  print("Audio duration: %f s." %(response_dict['result']['duration']))
   ```
 
   输出:
@@ -181,13 +188,63 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
   {'description': 'success.'}
   Save synthesized audio successfully on ./output.wav.
   Audio duration: 3.612500 s.
-  Response time: 0.388317 s.
   ```
+
+### 6. CLS 客户端使用方法
+**注意:** 初次使用客户端时响应时间会略长
+- 命令行 (推荐使用)
+  ```
+  paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
+  ```
+
+  使用帮助:
+  ```bash
+  paddlespeech_client cls --help
+  ```
+  参数:
+  - `server_ip`: 服务端 ip 地址,默认: 127.0.0.1。
+  - `port`: 服务端口,默认: 8090。
+  - `input`(必须输入): 用于分类的音频文件。
+  - `topk`: 分类结果的 topk。
+
+  输出:
+  ```bash
+  [2022-03-09 20:44:39,974] [    INFO] - {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}}
+  [2022-03-09 20:44:39,975] [    INFO] - Response time 0.104360 s.
+  ```
+
+- Python API
+  ```python
+  from paddlespeech.server.bin.paddlespeech_client import CLSClientExecutor
+  import json
+
+  clsclient_executor = CLSClientExecutor()
+  res = clsclient_executor(
+      input="./zh.wav",
+      server_ip="127.0.0.1",
+      port=8090,
+      topk=1)
+  print(res.json())
+  ```
+
+  输出:
+  ```bash
+  {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}}
+  ```
+
 ## 服务支持的模型
 ### ASR支持的模型
 通过 `paddlespeech_server stats --task asr` 获取ASR服务支持的所有模型,其中静态模型可用于 paddle inference 推理。
 
 ### TTS支持的模型
 通过 `paddlespeech_server stats --task tts` 获取TTS服务支持的所有模型,其中静态模型可用于 paddle inference 推理。
+
+### CLS支持的模型
+通过 `paddlespeech_server stats --task cls` 获取CLS服务支持的所有模型,其中静态模型可用于 paddle inference 推理。
diff --git a/demos/speech_server/cls_client.sh b/demos/speech_server/cls_client.sh
new file mode 100644
index 00000000000..5797aa204f6
--- /dev/null
+++ b/demos/speech_server/cls_client.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
+paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav --topk 1
diff --git a/demos/speech_server/conf/application.yaml b/demos/speech_server/conf/application.yaml
index 6048450b7ba..2b1a0599808 100644
--- a/demos/speech_server/conf/application.yaml
+++ b/demos/speech_server/conf/application.yaml
@@ -9,12 +9,14 @@ port: 8090
 # The task format in the engine_list is: <speech task>_<engine type>
 # task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
-engine_list: ['asr_python', 'tts_python']
+engine_list: ['asr_python', 'tts_python', 'cls_python']
 
 #################################################################################
 #                                ENGINE CONFIG                                  #
 #################################################################################
+
+################################### ASR #########################################
 ################### speech task: asr; engine_type: python #######################
 asr_python:
     model: 'conformer_wenetspeech'
@@ -46,6 +48,7 @@ asr_inference:
         summary: True # False -> do not show predictor config
 
+################################### TTS #########################################
 ################### speech task: tts; engine_type: python #######################
 tts_python:
     # am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
@@ -105,3 +108,30 @@ tts_inference:
     # others
     lang: 'zh'
+
+################################### CLS #########################################
+################### speech task: cls; engine_type: python #######################
+cls_python:
+    # model choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
+    model: 'panns_cnn14'
+    cfg_path:  # [optional] Config of cls task.
+    ckpt_path:  # [optional] Checkpoint file of model.
+    label_file:  # [optional] Label file of cls task.
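+    # Note: the three [optional] fields above may be left empty, in which case
+    # the server is expected to fall back to the pretrained resources bundled
+    # with the model name given above (an assumption based on the [optional]
+    # markers, not a documented guarantee).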
+    device:  # set 'gpu:id' or 'cpu'
+
+
+################### speech task: cls; engine_type: inference #######################
+cls_inference:
+    # model_type choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
+    model_type: 'panns_cnn14'
+    cfg_path:
+    model_path:  # the pdmodel file of cls static model [optional]
+    params_path:  # the pdiparams file of cls static model [optional]
+    label_file:  # [optional] Label file of cls task.
+
+    predictor_conf:
+        device:  # set 'gpu:id' or 'cpu'
+        switch_ir_optim: True
+        glog_info: False # True -> print glog
+        summary: True # False -> do not show predictor config
+
diff --git a/docs/source/reference.md b/docs/source/reference.md
index a8327e92e9d..f1a02d20009 100644
--- a/docs/source/reference.md
+++ b/docs/source/reference.md
@@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
 * [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md)
 - ISC License
 - Audio feature
+
+* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING)
+- zlib License
+- ThreadPool
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 8f855f7cf1e..62986da03d1 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -49,17 +49,19 @@ Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size (stat
 WaveFlow| LJSpeech |[waveflow-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0)|[waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip)|||
 Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip)|5.1MB|
 Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)|||
-Parallel WaveGAN|AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)|||
+Parallel WaveGAN| AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)|||
 Parallel WaveGAN| VCTK |[PWGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc1)|[pwg_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.5.zip)|||
 |Multi Band MelGAN | CSMSC |[MB MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc3) | [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip) <br> [mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip)|[mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip) |8.2MB|
 Style MelGAN | CSMSC |[Style MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc4)|[style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip)| | |
 HiFiGAN | CSMSC |[HiFiGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc5)|[hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip)|[hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip)|50MB|
+HiFiGAN | AISHELL-3 |[HiFiGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc5)|[hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip)|||
+HiFiGAN | VCTK |[HiFiGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc5)|[hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip)|||
 WaveRNN | CSMSC |[WaveRNN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc6)|[wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip)|[wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip)|18MB|
 
 ### Voice Cloning
 Model Type | Dataset| Example Link | Pretrained Models
-:-------------:| :------------:| :-----: | :-----:
+:-------------:| :------------:| :-----: | :-----: |
 GE2E| AISHELL-3, etc. |[ge2e](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/ge2e)|[ge2e_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ge2e/ge2e_ckpt_0.3.zip)
 GE2E + Tacotron2| AISHELL-3 |[ge2e-tactron2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vc0)|[tacotron2_aishell3_ckpt_vc0_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_aishell3_ckpt_vc0_0.2.0.zip)
 GE2E + FastSpeech2 | AISHELL-3 |[ge2e-fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vc1)|[fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip)
@@ -67,9 +69,9 @@
 
 ## Audio Classification Models
-Model Type | Dataset| Example Link | Pretrained Models
-:-------------:| :------------:| :-----: | :-----:
-PANN | Audioset| [audioset_tagging_cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn) | [panns_cnn6.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams), [panns_cnn10.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams), [panns_cnn14.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams)
+Model Type | Dataset| Example Link | Pretrained Models | Static Models
+:-------------:| :------------:| :-----: | :-----: | :-----:
+PANN | Audioset| [audioset_tagging_cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn) | [panns_cnn6.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams), [panns_cnn10.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams), [panns_cnn14.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams) | [panns_cnn6_static.tar.gz](https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz)(18M), [panns_cnn10_static.tar.gz](https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz)(19M), [panns_cnn14_static.tar.gz](https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz)(289M)
 PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn6.tar.gz), [esc50_cnn10.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn10.tar.gz), [esc50_cnn14.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn14.tar.gz)
 
 ## Punctuation Restoration Models
diff --git a/examples/aishell3/tts3/local/synthesize.sh b/examples/aishell3/tts3/local/synthesize.sh
index b1fc96a2d67..d3978833faa 100755
--- a/examples/aishell3/tts3/local/synthesize.sh
+++ b/examples/aishell3/tts3/local/synthesize.sh
@@ -4,18 +4,44 @@ config_path=$1
 train_output_path=$2
 ckpt_name=$3
 
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize.py \
-    --am=fastspeech2_aishell3 \
-    --am_config=${config_path} \
-    --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
-    --am_stat=dump/train/speech_stats.npy \
-    --voc=pwgan_aishell3 \
-    --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --test_metadata=dump/test/norm/metadata.jsonl \
-    --output_dir=${train_output_path}/test \
-    --phones_dict=dump/phone_id_map.txt \
-    --speaker_dict=dump/speaker_id_map.txt
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize.py \
+        --am=fastspeech2_aishell3 \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_aishell3 \
+        --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
+        --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+        --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize.py \
+        --am=fastspeech2_aishell3 \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=hifigan_aishell3 \
+        --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt
+fi
+
diff --git a/examples/aishell3/tts3/local/synthesize_e2e.sh b/examples/aishell3/tts3/local/synthesize_e2e.sh
index 60e1a5cee19..ff3608be7ae 100755
--- a/examples/aishell3/tts3/local/synthesize_e2e.sh
+++ b/examples/aishell3/tts3/local/synthesize_e2e.sh
@@ -4,21 +4,50 @@ config_path=$1
 train_output_path=$2
 ckpt_name=$3
 
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize_e2e.py \
-    --am=fastspeech2_aishell3 \
-    --am_config=${config_path} \
-    --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
-    --am_stat=dump/train/speech_stats.npy \
-    --voc=pwgan_aishell3 \
-    --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
-    --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
-    --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
-    --lang=zh \
-    --text=${BIN_DIR}/../sentences.txt \
-    --output_dir=${train_output_path}/test_e2e \
-    --phones_dict=dump/phone_id_map.txt \
-    --speaker_dict=dump/speaker_id_map.txt \
-    --spk_id=0 \
-    --inference_dir=${train_output_path}/inference
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_e2e.py \
+        --am=fastspeech2_aishell3 \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_aishell3 \
+        --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
+        --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+        --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+        --lang=zh \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt \
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "in hifigan syn_e2e"
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_e2e.py \
+        --am=fastspeech2_aishell3 \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=fastspeech2_nosil_aishell3_ckpt_0.4/speech_stats.npy \
+        --voc=hifigan_aishell3 \
+        --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+        --lang=zh \
+        --text=${BIN_DIR}/../sentences.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt \
+        --speaker_dict=fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt \
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
+fi
diff --git a/examples/aishell3/vc0/local/preprocess.sh b/examples/aishell3/vc0/local/preprocess.sh
index 069cf94c4ee..e458c7063b2 100755
--- a/examples/aishell3/vc0/local/preprocess.sh
+++ b/examples/aishell3/vc0/local/preprocess.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-stage=3
+stage=0
 stop_stage=100
 
 config_path=$1
diff --git a/examples/aishell3/voc1/run.sh b/examples/aishell3/voc1/run.sh
index 4f426ea02e1..cab1ac38b1a 100755
--- a/examples/aishell3/voc1/run.sh
+++ b/examples/aishell3/voc1/run.sh
@@ -3,7 +3,7 @@ set -e
 source path.sh
 
-gpus=0
+gpus=0,1
 stage=0
 stop_stage=100
 
diff --git a/examples/aishell3/voc5/README.md b/examples/aishell3/voc5/README.md
new file mode 100644
index 00000000000..ebe2530beec
--- /dev/null
+++ b/examples/aishell3/voc5/README.md
@@ -0,0 +1,156 @@
+# HiFiGAN with AISHELL-3
+This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [AISHELL-3](http://www.aishelltech.com/aishell_3).
+
+AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that can be used to train multi-speaker Text-to-Speech (TTS) systems.
+## Dataset
+### Download and Extract
+Download AISHELL-3.
+```bash
+wget https://www.openslr.org/resources/93/data_aishell3.tgz
+```
+Extract AISHELL-3.
+```bash
+mkdir data_aishell3
+tar zxvf data_aishell3.tgz -C data_aishell3
+```
+### Get MFA Result and Extract
+We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
+You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo.
+
+## Get Started
+Assume the path to the dataset is `~/datasets/data_aishell3`.
+Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`.
+Run the command below to
+1. **source path**.
+2. preprocess the dataset.
+3. train the model.
+4. synthesize wavs.
+    - synthesize waveform from `metadata.jsonl`.
+```bash
+./run.sh
+```
+You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage; for example, running the following command will only preprocess the dataset.
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Data Preprocessing
+```bash
+./local/preprocess.sh ${conf_path}
+```
+When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.
+ +```text +dump +├── dev +│ ├── norm +│ └── raw +├── test +│ ├── norm +│ └── raw +└── train + ├── norm + ├── raw + └── feats_stats.npy +``` + +The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the norm folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`. + +Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains id and paths to the spectrogram of each utterance. + +### Model Training +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. + +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER] + [--run-benchmark RUN_BENCHMARK] + [--profiler_options PROFILER_OPTIONS] + +Train a ParallelWaveGAN model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG config file to overwrite default config. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. + +benchmark: + arguments related to benchmark. + + --batch-size BATCH_SIZE + batch size. + --max-iter MAX_ITER train max steps. + --run-benchmark RUN_BENCHMARK + runing benchmark or not, if True, use the --batch-size + and --max-iter. + --profiler_options PROFILER_OPTIONS + The option of profiler, which should be in format + "key1=value1;key2=value2;key3=value3". +``` + +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +### Synthesizing +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG] + [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA] + [--output-dir OUTPUT_DIR] [--ngpu NGPU] + +Synthesize with GANVocoder. + +optional arguments: + -h, --help show this help message and exit + --generator-type GENERATOR_TYPE + type of GANVocoder, should in {pwgan, mb_melgan, + style_melgan, } now + --config CONFIG GANVocoder config file. + --checkpoint CHECKPOINT + snapshot to load. + --test-metadata TEST_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` + +1. `--config` config file. You should use the same config with which the model is trained. +2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory. +3. `--test-metadata` is the metadata of the test dataset. 
Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory. +4. `--output-dir` is the directory to save the synthesized audio files. +5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. +## Pretrained Models +The pretrained model can be downloaded here [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip). + + +Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss +:-------------:| :------------:| :-----: | :-----: | :--------: +default| 1(gpu) x 2500000|24.060|0.1068|7.499 + +HiFiGAN checkpoint contains files listed below. + +```text +hifigan_aishell3_ckpt_0.2.0 +├── default.yaml # default config used to train hifigan +├── feats_stats.npy # statistics used to normalize spectrogram when training hifigan +└── snapshot_iter_2500000.pdz # generator parameters of hifigan +``` + +## Acknowledgement +We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN. diff --git a/examples/aishell3/voc5/conf/default.yaml b/examples/aishell3/voc5/conf/default.yaml new file mode 100644 index 00000000000..728a9036909 --- /dev/null +++ b/examples/aishell3/voc5/conf/default.yaml @@ -0,0 +1,168 @@ +# This is the configuration file for AISHELL-3 dataset. +# This configuration is based on HiFiGAN V1, which is +# an official configuration. But I found that the optimizer +# setting does not work well with my implementation. +# So I changed optimizer settings as follows: +# - AdamW -> Adam +# - betas: [0.8, 0.99] -> betas: [0.5, 0.9] +# - Scheduler: ExponentialLR -> MultiStepLR +# To match the shift size difference, the upsample scales +# is also modified from the original 256 shift setting. +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size (samples). +n_shift: 300 # Hop size (samples). 12.5ms +win_length: 1200 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 80 # Number of input channels. + out_channels: 1 # Number of output channels. + channels: 512 # Number of initial channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + upsample_scales: [5, 5, 4, 3] # Upsampling scales. + upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers. + resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks. + resblock_dilations: # Dilations for residual blocks. + - [1, 3, 5] + - [1, 3, 5] + - [1, 3, 5] + use_additional_convs: True # Whether to use additional conv layer in residual blocks. + bias: True # Whether to use bias parameter in conv. + nonlinear_activation: "leakyrelu" # Nonlinear activation type. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. 
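+    # Note: the product of upsample_scales, 5 * 5 * 4 * 3 = 300, must equal
+    # n_shift above, so that one input mel frame is expanded to exactly one
+    # hop of output waveform samples.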
+ + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + scales: 3 # Number of multi-scale discriminator. + scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator. + scale_downsample_pooling_params: + kernel_size: 4 # Pooling kernel size. + stride: 2 # Pooling stride. + padding: 2 # Padding size. + scale_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. + channels: 128 # Initial number of channels. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + max_groups: 16 # Maximum number of groups in downsampling conv layers. + bias: True + downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: + negative_slope: 0.1 + follow_official_norm: True # Whether to follow the official norm setting. + periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. + period_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel sizes. + channels: 32 # Initial number of channels. + downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + bias: True # Whether to use bias parameter in conv layer." + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + use_spectral_norm: False # Whether to apply spectral normalization. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: False # Whether to use multi-resolution STFT loss. +use_mel_loss: True # Whether to use Mel-spectrogram loss. +mel_loss_params: + fs: 24000 + fft_size: 2048 + hop_size: 300 + win_length: 1200 + window: "hann" + num_mels: 80 + fmin: 0 + fmax: 12000 + log_base: null +generator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +use_feat_match_loss: True +feat_match_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. + average_by_layers: False # Whether to average loss by #layers in each discriminator. + include_final_outputs: False # Whether to include final outputs in feat match loss calculation. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_aux: 45.0 # Loss balancing coefficient for STFT loss. +lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss. +lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size. +batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 2 # Number of workers in DataLoader. 
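+# Note: batch_max_steps / n_shift = 8400 / 300 = 28, so each training clip
+# covers exactly 28 mel frames; keep batch_max_steps divisible by the hop
+# size when changing either value.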
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+    beta1: 0.5
+    beta2: 0.9
+    weight_decay: 0.0                 # Generator's weight decay coefficient.
+generator_scheduler_params:
+    learning_rate: 2.0e-4             # Generator's learning rate.
+    gamma: 0.5                        # Generator's scheduler gamma.
+    milestones:                       # At each milestone, lr will be multiplied by gamma.
+        - 200000
+        - 400000
+        - 600000
+        - 800000
+generator_grad_norm: -1               # Generator's gradient norm.
+discriminator_optimizer_params:
+    beta1: 0.5
+    beta2: 0.9
+    weight_decay: 0.0                 # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+    learning_rate: 2.0e-4             # Discriminator's learning rate.
+    gamma: 0.5                        # Discriminator's scheduler gamma.
+    milestones:                       # At each milestone, lr will be multiplied by gamma.
+        - 200000
+        - 400000
+        - 600000
+        - 800000
+discriminator_grad_norm: -1           # Discriminator's gradient norm.
+
+###########################################################
+#                     INTERVAL SETTING                    #
+###########################################################
+generator_train_start_steps: 1        # Number of steps to start to train generator.
+discriminator_train_start_steps: 0    # Number of steps to start to train discriminator.
+train_max_steps: 2500000              # Number of training steps.
+save_interval_steps: 5000             # Interval steps to save checkpoint.
+eval_interval_steps: 1000             # Interval steps to evaluate the network.
+
+###########################################################
+#                     OTHER SETTING                       #
+###########################################################
+num_snapshots: 10                     # max number of snapshots to keep while training
+seed: 42                              # random seed for paddle, random, and np.random
diff --git a/examples/aishell3/voc5/local/preprocess.sh b/examples/aishell3/voc5/local/preprocess.sh
new file mode 100755
index 00000000000..44cc3dbe460
--- /dev/null
+++ b/examples/aishell3/voc5/local/preprocess.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+stage=0
+stop_stage=100
+
+config_path=$1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # get durations from MFA's result
+    echo "Generate durations.txt from MFA results ..."
+    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+        --inputdir=./aishell3_alignment_tone \
+        --output=durations.txt \
+        --config=${config_path}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${BIN_DIR}/../preprocess.py \
+        --rootdir=~/datasets/data_aishell3/ \
+        --dataset=aishell3 \
+        --dumpdir=dump \
+        --dur-file=durations.txt \
+        --config=${config_path} \
+        --cut-sil=True \
+        --num-cpu=20
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # get features' stats (mean and std)
+    echo "Get features' stats ..."
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="feats"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # normalize, dev and test should use train's stats
+    echo "Normalize ..."
+    python3 ${BIN_DIR}/../normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --stats=dump/train/feats_stats.npy
+    python3 ${BIN_DIR}/../normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --stats=dump/train/feats_stats.npy
+    python3 ${BIN_DIR}/../normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --stats=dump/train/feats_stats.npy
+fi
diff --git a/examples/aishell3/voc5/local/synthesize.sh b/examples/aishell3/voc5/local/synthesize.sh
new file mode 100755
index 00000000000..6478961756f
--- /dev/null
+++ b/examples/aishell3/voc5/local/synthesize.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+FLAGS_allocator_strategy=naive_best_fit \
+FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+python3 ${BIN_DIR}/../synthesize.py \
+    --config=${config_path} \
+    --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+    --test-metadata=dump/test/norm/metadata.jsonl \
+    --output-dir=${train_output_path}/test \
+    --generator-type=hifigan
diff --git a/examples/aishell3/voc5/local/train.sh b/examples/aishell3/voc5/local/train.sh
new file mode 100755
index 00000000000..9695631ef02
--- /dev/null
+++ b/examples/aishell3/voc5/local/train.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+FLAGS_cudnn_exhaustive_search=true \
+FLAGS_conv_workspace_size_limit=4000 \
+python ${BIN_DIR}/train.py \
+    --train-metadata=dump/train/norm/metadata.jsonl \
+    --dev-metadata=dump/dev/norm/metadata.jsonl \
+    --config=${config_path} \
+    --output-dir=${train_output_path} \
+    --ngpu=1
diff --git a/examples/aishell3/voc5/path.sh b/examples/aishell3/voc5/path.sh
new file mode 100755
index 00000000000..7451b3218e2
--- /dev/null
+++ b/examples/aishell3/voc5/path.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=hifigan
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
diff --git a/examples/aishell3/voc5/run.sh b/examples/aishell3/voc5/run.sh
new file mode 100755
index 00000000000..4f426ea02e1
--- /dev/null
+++ b/examples/aishell3/voc5/run.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_5000.pdz
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with positional arguments `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/examples/ami/sd0/local/ami_prepare.py b/examples/ami/sd0/local/ami_prepare.py index d03810a777a..01582dbdd33 100644 --- a/examples/ami/sd0/local/ami_prepare.py +++ b/examples/ami/sd0/local/ami_prepare.py @@ -17,11 +17,8 @@ Download: http://groups.inf.ed.ac.uk/ami/download/ Prepares metadata files (JSON) from manual annotations "segments/" using RTTM format (Oracle VAD). - -Authors - * qingenz123@126.com (Qingen ZHAO) 2022 - """ + import argparse import glob import json diff --git a/examples/ami/sd0/local/ami_splits.py b/examples/ami/sd0/local/ami_splits.py index 010638a3969..a8bc5dc8485 100644 --- a/examples/ami/sd0/local/ami_splits.py +++ b/examples/ami/sd0/local/ami_splits.py @@ -15,10 +15,6 @@ AMI corpus contained 100 hours of meeting recording. This script returns the standard train, dev and eval split for AMI corpus. For more information on dataset please refer to http://groups.inf.ed.ac.uk/ami/corpus/datasets.shtml - -Authors - * qingenz123@126.com (Qingen ZHAO) 2022 - """ ALLOWED_OPTIONS = ["scenario_only", "full_corpus", "full_corpus_asr"] diff --git a/examples/ami/sd0/local/dataio.py b/examples/ami/sd0/local/dataio.py index f7fe881573c..4ff76bd5bd1 100644 --- a/examples/ami/sd0/local/dataio.py +++ b/examples/ami/sd0/local/dataio.py @@ -13,10 +13,6 @@ # limitations under the License. """ Data reading and writing. 
- -Authors - * qingenz123@126.com (Qingen ZHAO) 2022 - """ import os import pickle diff --git a/examples/csmsc/tts0/local/synthesize_e2e.sh b/examples/csmsc/tts0/local/synthesize_e2e.sh index f7675873386..4c3b08dc1f5 100755 --- a/examples/csmsc/tts0/local/synthesize_e2e.sh +++ b/examples/csmsc/tts0/local/synthesize_e2e.sh @@ -7,7 +7,7 @@ ckpt_name=$3 stage=0 stop_stage=0 -# TODO: tacotron2 动转静的结果没有静态图的响亮, 可能还是 decode 的时候某个函数动静不对齐 +# TODO: tacotron2 动转静的结果没有动态图的响亮, 可能还是 decode 的时候某个函数动静不对齐 # pwgan if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then FLAGS_allocator_strategy=naive_best_fit \ diff --git a/examples/csmsc/tts2/local/synthesize.sh b/examples/csmsc/tts2/local/synthesize.sh index 37b2981831e..b8982a16da2 100755 --- a/examples/csmsc/tts2/local/synthesize.sh +++ b/examples/csmsc/tts2/local/synthesize.sh @@ -14,7 +14,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --am=speedyspeech_csmsc \ --am_config=${config_path} \ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ + --am_stat=dump/train/feats_stats.npy \ --voc=pwgan_csmsc \ --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \ --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \ @@ -34,7 +34,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --am=speedyspeech_csmsc \ --am_config=${config_path} \ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ + --am_stat=dump/train/feats_stats.npy \ --voc=mb_melgan_csmsc \ --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \ --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\ @@ -53,7 +53,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then --am=speedyspeech_csmsc \ --am_config=${config_path} \ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ + --am_stat=dump/train/feats_stats.npy \ --voc=style_melgan_csmsc \ --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \ --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \ @@ -73,7 +73,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --am=speedyspeech_csmsc \ --am_config=${config_path} \ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ + --am_stat=dump/train/feats_stats.npy \ --voc=hifigan_csmsc \ --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \ --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \ @@ -93,7 +93,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --am=speedyspeech_csmsc \ --am_config=${config_path} \ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ - --am_stat=dump/train/speech_stats.npy \ + --am_stat=dump/train/feats_stats.npy \ --voc=wavernn_csmsc \ --voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \ --voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \ diff --git a/examples/ljspeech/voc5/README.md b/examples/ljspeech/voc5/README.md new file mode 100644 index 00000000000..21082942845 --- /dev/null +++ b/examples/ljspeech/voc5/README.md @@ -0,0 +1,133 @@ +# HiFiGAN with the LJSpeech-1.1 +This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/). +## Dataset +### Download and Extract +Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/). 
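+For example, a typical way to fetch and unpack it (the archive name and URL below are taken from the official page and may change):
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+tar xjvf LJSpeech-1.1.tar.bz2
+```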
+### Get MFA Result and Extract
+We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edges of the audio.
+You can download it from here: [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.
+
+## Get Started
+Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
+Assume the path to the MFA result of LJSpeech-1.1 is `./ljspeech_alignment`.
+Run the command below to
+1. **source path**.
+2. preprocess the dataset.
+3. train the model.
+4. synthesize wavs.
+    - synthesize waveform from `metadata.jsonl`.
+```bash
+./run.sh
+```
+You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage; for example, running the following command will only preprocess the dataset.
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Data Preprocessing
+```bash
+./local/preprocess.sh ${conf_path}
+```
+When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.
+
+```text
+dump
+├── dev
+│   ├── norm
+│   └── raw
+├── test
+│   ├── norm
+│   └── raw
+└── train
+    ├── norm
+    ├── raw
+    └── feats_stats.npy
+```
+
+The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`.
+
+Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains id and paths to the spectrogram of each utterance.
+
+### Model Training
+`./local/train.sh` calls `${BIN_DIR}/train.py`.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
+```
+Here's the complete help message.
+
+```text
+usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
+                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
+                [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
+                [--run-benchmark RUN_BENCHMARK]
+                [--profiler_options PROFILER_OPTIONS]
+
+Train a ParallelWaveGAN model.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --config CONFIG       config file to overwrite default config.
+  --train-metadata TRAIN_METADATA
+                        training data.
+  --dev-metadata DEV_METADATA
+                        dev data.
+  --output-dir OUTPUT_DIR
+                        output dir.
+  --ngpu NGPU           if ngpu == 0, use cpu.
+
+benchmark:
+  arguments related to benchmark.
+
+  --batch-size BATCH_SIZE
+                        batch size.
+  --max-iter MAX_ITER   train max steps.
+  --run-benchmark RUN_BENCHMARK
+                        runing benchmark or not, if True, use the --batch-size
+                        and --max-iter.
+  --profiler_options PROFILER_OPTIONS
+                        The option of profiler, which should be in format
+                        "key1=value1;key2=value2;key3=value3".
+```
+
+1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
+2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
+3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
+4.
+
+### Synthesizing
+`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
+```
+```text
+usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
+                     [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
+                     [--output-dir OUTPUT_DIR] [--ngpu NGPU]
+
+Synthesize with GANVocoder.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --generator-type GENERATOR_TYPE
+                        type of GANVocoder, should in {pwgan, mb_melgan,
+                        style_melgan, } now
+  --config CONFIG       GANVocoder config file.
+  --checkpoint CHECKPOINT
+                        snapshot to load.
+  --test-metadata TEST_METADATA
+                        dev data.
+  --output-dir OUTPUT_DIR
+                        output dir.
+  --ngpu NGPU           if ngpu == 0, use cpu.
+```
+
+1. `--config` is the vocoder config file. You should use the same config with which the model is trained.
+2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
+3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
+4. `--output-dir` is the directory to save the synthesized audio files.
+5. `--ngpu` is the number of GPUs to use. If `ngpu == 0`, the CPU is used.
+
+## Pretrained Model
+
+
+## Acknowledgement
+We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
diff --git a/examples/ljspeech/voc5/conf/default.yaml b/examples/ljspeech/voc5/conf/default.yaml
new file mode 100644
index 00000000000..97c51220409
--- /dev/null
+++ b/examples/ljspeech/voc5/conf/default.yaml
@@ -0,0 +1,167 @@
+# This is the configuration file for the LJSpeech dataset.
+# This configuration is based on HiFiGAN V1, which is an official configuration.
+# But I found that the optimizer setting does not work well with my implementation.
+# So I changed the optimizer settings as follows:
+# - AdamW -> Adam
+# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
+# - Scheduler: ExponentialLR -> MultiStepLR
+# To match the shift size difference, the upsample scales are also modified from the original 256 shift setting.
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+fs: 22050                # Sampling rate.
+n_fft: 1024              # FFT size (samples).
+n_shift: 256             # Hop size (samples). 11.6ms
+win_length: null         # Window length (samples).
+                         # If set to null, it will be the same as fft_size.
+window: "hann"           # Window function.
+n_mels: 80               # Number of mel basis.
+fmin: 80                 # Minimum freq in mel basis calculation. (Hz)
+fmax: 7600               # Maximum frequency in mel basis calculation. (Hz)
+
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+    in_channels: 80                       # Number of input channels.
+    out_channels: 1                       # Number of output channels.
+    channels: 512                         # Number of initial channels.
+    kernel_size: 7                        # Kernel size of initial and final conv layers.
+    upsample_scales: [8, 8, 2, 2]         # Upsampling scales.
+    upsample_kernel_sizes: [16, 16, 4, 4] # Kernel size for upsampling layers.
+    resblock_kernel_sizes: [3, 7, 11]     # Kernel size for residual blocks.
+    resblock_dilations:                   # Dilations for residual blocks.
+ - [1, 3, 5] + - [1, 3, 5] + - [1, 3, 5] + use_additional_convs: True # Whether to use additional conv layer in residual blocks. + bias: True # Whether to use bias parameter in conv. + nonlinear_activation: "leakyrelu" # Nonlinear activation type. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + scales: 3 # Number of multi-scale discriminator. + scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator. + scale_downsample_pooling_params: + kernel_size: 4 # Pooling kernel size. + stride: 2 # Pooling stride. + padding: 2 # Padding size. + scale_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. + channels: 128 # Initial number of channels. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + max_groups: 16 # Maximum number of groups in downsampling conv layers. + bias: True + downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: + negative_slope: 0.1 + follow_official_norm: True # Whether to follow the official norm setting. + periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. + period_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel sizes. + channels: 32 # Initial number of channels. + downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + bias: True # Whether to use bias parameter in conv layer." + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + use_spectral_norm: False # Whether to apply spectral normalization. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: False # Whether to use multi-resolution STFT loss. +use_mel_loss: True # Whether to use Mel-spectrogram loss. +mel_loss_params: + fs: 22050 + fft_size: 1024 + hop_size: 256 + win_length: null + window: "hann" + num_mels: 80 + fmin: 0 + fmax: 11025 + log_base: null +generator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +use_feat_match_loss: True +feat_match_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. + average_by_layers: False # Whether to average loss by #layers in each discriminator. + include_final_outputs: False # Whether to include final outputs in feat match loss calculation. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_aux: 45.0 # Loss balancing coefficient for STFT loss. +lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss. 
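+# (Roughly: total generator loss ≈ lambda_aux * mel/STFT loss + lambda_adv * adversarial loss + lambda_feat_match * feature matching loss.)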
+lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size. +batch_max_steps: 8192 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 2 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 2.0e-4 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +generator_grad_norm: -1 # Generator's gradient norm. +discriminator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 2.0e-4 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +discriminator_grad_norm: -1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +generator_train_start_steps: 1 # Number of steps to start to train discriminator. +discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. +train_max_steps: 2500000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/ljspeech/voc5/local/preprocess.sh b/examples/ljspeech/voc5/local/preprocess.sh new file mode 100755 index 00000000000..d1af60dad6a --- /dev/null +++ b/examples/ljspeech/voc5/local/preprocess.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./ljspeech_alignment \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/../preprocess.py \ + --rootdir=~/datasets/LJSpeech-1.1/ \ + --dataset=ljspeech \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --cut-sil=True \ + --num-cpu=20 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize, dev and test should use train's stats + echo "Normalize ..." 
+ + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --stats=dump/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --stats=dump/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --stats=dump/train/feats_stats.npy +fi diff --git a/examples/ljspeech/voc5/local/synthesize.sh b/examples/ljspeech/voc5/local/synthesize.sh new file mode 100755 index 00000000000..6478961756f --- /dev/null +++ b/examples/ljspeech/voc5/local/synthesize.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/../synthesize.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test \ + --generator-type=hifigan diff --git a/examples/ljspeech/voc5/local/train.sh b/examples/ljspeech/voc5/local/train.sh new file mode 100755 index 00000000000..9695631ef02 --- /dev/null +++ b/examples/ljspeech/voc5/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +FLAGS_cudnn_exhaustive_search=true \ +FLAGS_conv_workspace_size_limit=4000 \ +python ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=1 diff --git a/examples/ljspeech/voc5/path.sh b/examples/ljspeech/voc5/path.sh new file mode 100755 index 00000000000..7451b3218e2 --- /dev/null +++ b/examples/ljspeech/voc5/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=hifigan +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL} diff --git a/examples/ljspeech/voc5/run.sh b/examples/ljspeech/voc5/run.sh new file mode 100755 index 00000000000..cab1ac38b1a --- /dev/null +++ b/examples/ljspeech/voc5/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_5000.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... 
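+# `parse_options.sh` maps `--name value` arguments onto the shell variables
+# declared above, e.g. `--stage 1 --stop-stage 1` runs only the training stage.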
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # synthesize
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
diff --git a/examples/vctk/tts3/local/synthesize.sh b/examples/vctk/tts3/local/synthesize.sh
index 8381af464e6..9e03f9b8a47 100755
--- a/examples/vctk/tts3/local/synthesize.sh
+++ b/examples/vctk/tts3/local/synthesize.sh
@@ -4,18 +4,43 @@ config_path=$1
 train_output_path=$2
 ckpt_name=$3
 
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize.py \
-    --am=fastspeech2_vctk \
-    --am_config=${config_path} \
-    --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
-    --am_stat=dump/train/speech_stats.npy \
-    --voc=pwgan_vctk \
-    --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
-    --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
-    --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
-    --test_metadata=dump/test/norm/metadata.jsonl \
-    --output_dir=${train_output_path}/test \
-    --phones_dict=dump/phone_id_map.txt \
-    --speaker_dict=dump/speaker_id_map.txt
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize.py \
+        --am=fastspeech2_vctk \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_vctk \
+        --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
+        --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+        --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize.py \
+        --am=fastspeech2_vctk \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=hifigan_vctk \
+        --voc_config=hifigan_vctk_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_vctk_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_vctk_ckpt_0.2.0/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt
+fi
diff --git a/examples/vctk/tts3/local/synthesize_e2e.sh b/examples/vctk/tts3/local/synthesize_e2e.sh
index 60d56d1c9cb..a89f42b50da 100755
--- a/examples/vctk/tts3/local/synthesize_e2e.sh
+++ b/examples/vctk/tts3/local/synthesize_e2e.sh
@@ -4,21 +4,49 @@ config_path=$1
 train_output_path=$2
 ckpt_name=$3
 
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize_e2e.py \
-    --am=fastspeech2_vctk \
-    --am_config=${config_path} \
-    --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
-    --am_stat=dump/train/speech_stats.npy \
-    --voc=pwgan_vctk \
-    --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
-    --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
-    --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
-    --lang=en \
-    --text=${BIN_DIR}/../sentences_en.txt \
-    --output_dir=${train_output_path}/test_e2e \
-    --phones_dict=dump/phone_id_map.txt \
-    --speaker_dict=dump/speaker_id_map.txt \
-    --spk_id=0 \
-    --inference_dir=${train_output_path}/inference
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_e2e.py \
+        --am=fastspeech2_vctk \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_vctk \
+        --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
+        --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+        --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
+        --lang=en \
+        --text=${BIN_DIR}/../sentences_en.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt \
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_e2e.py \
+        --am=fastspeech2_vctk \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=hifigan_vctk \
+        --voc_config=hifigan_vctk_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_vctk_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_vctk_ckpt_0.2.0/feats_stats.npy \
+        --lang=en \
+        --text=${BIN_DIR}/../sentences_en.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt \
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
+fi
diff --git a/examples/vctk/voc5/README.md b/examples/vctk/voc5/README.md
new file mode 100644
index 00000000000..b4be341c0e5
--- /dev/null
+++ b/examples/vctk/voc5/README.md
@@ -0,0 +1,153 @@
+# HiFiGAN with VCTK
+This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [VCTK](https://datashare.ed.ac.uk/handle/10283/3443).
+
+## Dataset
+### Download and Extract
+Download VCTK-0.92 from the [official website](https://datashare.ed.ac.uk/handle/10283/3443) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/VCTK-Corpus-0.92`.
+
+### Get MFA Result and Extract
+We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
+You can download it from here [vctk_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/VCTK-Corpus-0.92/vctk_alignment.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.
+PS: we removed three speakers from VCTK-0.92 (see [reorganize_vctk.py](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/examples/other/mfa/local/reorganize_vctk.py)):
+1. `p315`, because there is no text for it.
+2. `p280` and `p362`, because there are no `*_mic2.flac` files (which are better than `*_mic1.flac`) for them.
+
+## Get Started
+Assume the path to the dataset is `~/datasets/VCTK-Corpus-0.92`.
+Assume the path to the MFA result of VCTK is `./vctk_alignment`.
+Run the command below to
+1. **source path**.
+2. preprocess the dataset.
+3. train the model.
+4. synthesize wavs.
+    - synthesize waveform from `metadata.jsonl`.
+```bash
+./run.sh
+```
+You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage. For example, running the following command will only preprocess the dataset.
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Data Preprocessing
+```bash
+./local/preprocess.sh ${conf_path}
+```
+When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.
+
+```text
+dump
+├── dev
+│   ├── norm
+│   └── raw
+├── test
+│   ├── norm
+│   └── raw
+└── train
+    ├── norm
+    ├── raw
+    └── feats_stats.npy
+```
+
+The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`.
+
+Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains the id and the paths to the spectrogram of each utterance.
+
+### Model Training
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
+```
+`./local/train.sh` calls `${BIN_DIR}/train.py`.
+Here's the complete help message.
+
+```text
+usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
+                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
+                [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
+                [--run-benchmark RUN_BENCHMARK]
+                [--profiler_options PROFILER_OPTIONS]
+
+Train a ParallelWaveGAN model.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --config CONFIG       config file to overwrite default config.
+  --train-metadata TRAIN_METADATA
+                        training data.
+  --dev-metadata DEV_METADATA
+                        dev data.
+  --output-dir OUTPUT_DIR
+                        output dir.
+  --ngpu NGPU           if ngpu == 0, use cpu.
+
+benchmark:
+  arguments related to benchmark.
+
+  --batch-size BATCH_SIZE
+                        batch size.
+  --max-iter MAX_ITER   train max steps.
+  --run-benchmark RUN_BENCHMARK
+                        runing benchmark or not, if True, use the --batch-size
+                        and --max-iter.
+  --profiler_options PROFILER_OPTIONS
+                        The option of profiler, which should be in format
+                        "key1=value1;key2=value2;key3=value3".
+```
+
+1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
+2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
+3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
+4. `--ngpu` is the number of GPUs to use. If `ngpu == 0`, the CPU is used.
+
+### Synthesizing
+`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
+```
+```text
+usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
+                     [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
+                     [--output-dir OUTPUT_DIR] [--ngpu NGPU]
+
+Synthesize with GANVocoder.
+ +optional arguments: + -h, --help show this help message and exit + --generator-type GENERATOR_TYPE + type of GANVocoder, should in {pwgan, mb_melgan, + style_melgan, } now + --config CONFIG GANVocoder config file. + --checkpoint CHECKPOINT + snapshot to load. + --test-metadata TEST_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` + + +1. `--config` config file. You should use the same config with which the model is trained. +2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory. +3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory. +4. `--output-dir` is the directory to save the synthesized audio files. +5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +## Pretrained Model +The pretrained model can be downloaded here [hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip). + + +Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss +:-------------:| :------------:| :-----: | :-----: | :--------: +default| 1(gpu) x 2500000|58.092|0.1234|24.384 + +HiFiGAN checkpoint contains files listed below. + +```text +hifigan_vctk_ckpt_0.2.0 +├── default.yaml # default config used to train hifigan +├── feats_stats.npy # statistics used to normalize spectrogram when training hifigan +└── snapshot_iter_2500000.pdz # generator parameters of hifigan +``` + +## Acknowledgement +We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN. diff --git a/examples/vctk/voc5/conf/default.yaml b/examples/vctk/voc5/conf/default.yaml new file mode 100644 index 00000000000..6361e01b221 --- /dev/null +++ b/examples/vctk/voc5/conf/default.yaml @@ -0,0 +1,168 @@ +# This is the configuration file for VCTK dataset. +# This configuration is based on HiFiGAN V1, which is +# an official configuration. But I found that the optimizer +# setting does not work well with my implementation. +# So I changed optimizer settings as follows: +# - AdamW -> Adam +# - betas: [0.8, 0.99] -> betas: [0.5, 0.9] +# - Scheduler: ExponentialLR -> MultiStepLR +# To match the shift size difference, the upsample scales +# is also modified from the original 256 shift setting. +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size (samples). +n_shift: 300 # Hop size (samples). 12.5ms +win_length: 1200 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 80 # Number of input channels. + out_channels: 1 # Number of output channels. + channels: 512 # Number of initial channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + upsample_scales: [5, 5, 4, 3] # Upsampling scales. + upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers. + resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks. 
+ resblock_dilations: # Dilations for residual blocks. + - [1, 3, 5] + - [1, 3, 5] + - [1, 3, 5] + use_additional_convs: True # Whether to use additional conv layer in residual blocks. + bias: True # Whether to use bias parameter in conv. + nonlinear_activation: "leakyrelu" # Nonlinear activation type. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + scales: 3 # Number of multi-scale discriminator. + scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator. + scale_downsample_pooling_params: + kernel_size: 4 # Pooling kernel size. + stride: 2 # Pooling stride. + padding: 2 # Padding size. + scale_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. + channels: 128 # Initial number of channels. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + max_groups: 16 # Maximum number of groups in downsampling conv layers. + bias: True + downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: + negative_slope: 0.1 + follow_official_norm: True # Whether to follow the official norm setting. + periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. + period_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel sizes. + channels: 32 # Initial number of channels. + downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + bias: True # Whether to use bias parameter in conv layer." + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + use_spectral_norm: False # Whether to apply spectral normalization. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: False # Whether to use multi-resolution STFT loss. +use_mel_loss: True # Whether to use Mel-spectrogram loss. +mel_loss_params: + fs: 24000 + fft_size: 2048 + hop_size: 300 + win_length: 1200 + window: "hann" + num_mels: 80 + fmin: 0 + fmax: 12000 + log_base: null +generator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +use_feat_match_loss: True +feat_match_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. + average_by_layers: False # Whether to average loss by #layers in each discriminator. + include_final_outputs: False # Whether to include final outputs in feat match loss calculation. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_aux: 45.0 # Loss balancing coefficient for STFT loss. 
+lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss. +lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size. +batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 2 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 2.0e-4 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +generator_grad_norm: -1 # Generator's gradient norm. +discriminator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 2.0e-4 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +discriminator_grad_norm: -1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +generator_train_start_steps: 1 # Number of steps to start to train discriminator. +discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. +train_max_steps: 2500000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/vctk/voc5/local/preprocess.sh b/examples/vctk/voc5/local/preprocess.sh new file mode 100755 index 00000000000..88a478cd537 --- /dev/null +++ b/examples/vctk/voc5/local/preprocess.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./vctk_alignment \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/../preprocess.py \ + --rootdir=~/datasets/VCTK-Corpus-0.92/ \ + --dataset=vctk \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --cut-sil=True \ + --num-cpu=20 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize, dev and test should use train's stats + echo "Normalize ..." 
+ + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --stats=dump/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --stats=dump/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --stats=dump/train/feats_stats.npy +fi diff --git a/examples/vctk/voc5/local/synthesize.sh b/examples/vctk/voc5/local/synthesize.sh new file mode 100755 index 00000000000..6478961756f --- /dev/null +++ b/examples/vctk/voc5/local/synthesize.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/../synthesize.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test \ + --generator-type=hifigan diff --git a/examples/vctk/voc5/local/train.sh b/examples/vctk/voc5/local/train.sh new file mode 100755 index 00000000000..9695631ef02 --- /dev/null +++ b/examples/vctk/voc5/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +FLAGS_cudnn_exhaustive_search=true \ +FLAGS_conv_workspace_size_limit=4000 \ +python ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=1 diff --git a/examples/vctk/voc5/path.sh b/examples/vctk/voc5/path.sh new file mode 100755 index 00000000000..7451b3218e2 --- /dev/null +++ b/examples/vctk/voc5/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=hifigan +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL} diff --git a/examples/vctk/voc5/run.sh b/examples/vctk/voc5/run.sh new file mode 100755 index 00000000000..4f426ea02e1 --- /dev/null +++ b/examples/vctk/voc5/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_5000.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... 
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # synthesize
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
diff --git a/paddleaudio/CHANGELOG.md b/paddleaudio/CHANGELOG.md
index 91b0fef08e6..925d7769684 100644
--- a/paddleaudio/CHANGELOG.md
+++ b/paddleaudio/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Changelog
+Date: 2022-3-15, Author: Xiaojie Chen.
+ - kaldi and librosa mfcc, fbank, spectrogram.
+ - unit test and benchmark.
+
 Date: 2022-2-25, Author: Hui Zhang.
   - Refactor architecture.
-  - dtw distance and mcd style dtw
+  - dtw distance and mcd style dtw.
diff --git a/paddleaudio/paddleaudio/backends/soundfile_backend.py b/paddleaudio/paddleaudio/backends/soundfile_backend.py
index 2b920284a6c..c1155654f2f 100644
--- a/paddleaudio/paddleaudio/backends/soundfile_backend.py
+++ b/paddleaudio/paddleaudio/backends/soundfile_backend.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import warnings
 from typing import Optional
 from typing import Tuple
@@ -19,7 +20,6 @@
 import numpy as np
 import resampy
 import soundfile as sf
-from numpy import ndarray as array
 from scipy.io import wavfile
 
 from ..utils import ParameterError
@@ -38,13 +38,21 @@
 EPS = 1e-8
 
 
-def resample(y: array, src_sr: int, target_sr: int,
-             mode: str='kaiser_fast') -> array:
-    """ Audio resampling
-    This function is the same as using resampy.resample().
-    Notes:
-        The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast'
-    """
+def resample(y: np.ndarray,
+             src_sr: int,
+             target_sr: int,
+             mode: str='kaiser_fast') -> np.ndarray:
+    """Audio resampling.
+
+    Args:
+        y (np.ndarray): Input waveform array in 1D or 2D.
+        src_sr (int): Source sample rate.
+        target_sr (int): Target sample rate.
+        mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'.
+
+    Returns:
+        np.ndarray: `y` resampled to `target_sr`.
+    """
 
     if mode == 'kaiser_best':
         warnings.warn(
@@ -53,7 +61,7 @@ def resample(y: array, src_sr: int, target_sr: int,
 
     if not isinstance(y, np.ndarray):
         raise ParameterError(
-            'Only support numpy array, but received y in {type(y)}')
+            f'Only support numpy np.ndarray, but received y of type {type(y)}')
 
     if mode not in RESAMPLE_MODES:
         raise ParameterError(f'resample mode must in {RESAMPLE_MODES}')
@@ -61,9 +69,17 @@ def resample(y: array, src_sr: int, target_sr: int,
     return resampy.resample(y, src_sr, target_sr, filter=mode)
 
 
-def to_mono(y: array, merge_type: str='average') -> array:
-    """ convert sterior audio to mono
+def to_mono(y: np.ndarray, merge_type: str='average') -> np.ndarray:
+    """Convert stereo audio to mono.
+
+    Args:
+        y (np.ndarray): Input waveform array in 1D or 2D.
+        merge_type (str, optional): Merge type to generate mono waveform. Defaults to 'average'.
+
+    Returns:
+        np.ndarray: `y` with a mono channel.
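+
+    Example (illustrative, assuming a channels-first `(2, T)` layout): with
+    `merge_type='average'`, a float stereo waveform is reduced to the
+    per-sample mean of its two channels, giving shape `(T,)`.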
""" + if merge_type not in MERGE_TYPES: raise ParameterError( f'Unsupported merge type {merge_type}, available types are {MERGE_TYPES}' @@ -101,18 +117,34 @@ def to_mono(y: array, merge_type: str='average') -> array: return y_out -def _safe_cast(y: array, dtype: Union[type, str]) -> array: - """ data type casting in a safe way, i.e., prevent overflow or underflow - This function is used internally. +def _safe_cast(y: np.ndarray, dtype: Union[type, str]) -> np.ndarray: + """Data type casting in a safe way, i.e., prevent overflow or underflow. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + dtype (Union[type, str]): Data type of waveform. + + Returns: + np.ndarray: `y` after safe casting. """ - return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype) + if 'float' in str(y.dtype): + return np.clip(y, np.finfo(dtype).min, + np.finfo(dtype).max).astype(dtype) + else: + return np.clip(y, np.iinfo(dtype).min, + np.iinfo(dtype).max).astype(dtype) -def depth_convert(y: array, dtype: Union[type, str], - dithering: bool=True) -> array: - """Convert audio array to target dtype safely - This function convert audio waveform to a target dtype, with addition steps of +def depth_convert(y: np.ndarray, dtype: Union[type, str]) -> np.ndarray: + """Convert audio array to target dtype safely. This function convert audio waveform to a target dtype, with addition steps of preventing overflow/underflow and preserving audio range. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + dtype (Union[type, str]): Data type of waveform. + + Returns: + np.ndarray: `y` after safe casting. """ SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64'] @@ -157,14 +189,20 @@ def depth_convert(y: array, dtype: Union[type, str], return y -def sound_file_load(file: str, +def sound_file_load(file: os.PathLike, offset: Optional[float]=None, dtype: str='int16', - duration: Optional[int]=None) -> Tuple[array, int]: - """Load audio using soundfile library - This function load audio file using libsndfile. - Reference: - http://www.mega-nerd.com/libsndfile/#Features + duration: Optional[int]=None) -> Tuple[np.ndarray, int]: + """Load audio using soundfile library. This function load audio file using libsndfile. + + Args: + file (os.PathLike): File of waveform. + offset (Optional[float], optional): Offset to the start of waveform. Defaults to None. + dtype (str, optional): Data type of waveform. Defaults to 'int16'. + duration (Optional[int], optional): Duration of waveform to read. Defaults to None. + + Returns: + Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate. """ with sf.SoundFile(file) as sf_desc: sr_native = sf_desc.samplerate @@ -179,9 +217,17 @@ def sound_file_load(file: str, return y, sf_desc.samplerate -def normalize(y: array, norm_type: str='linear', - mul_factor: float=1.0) -> array: - """ normalize an input audio with additional multiplier. +def normalize(y: np.ndarray, norm_type: str='linear', + mul_factor: float=1.0) -> np.ndarray: + """Normalize an input audio with additional multiplier. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + norm_type (str, optional): Type of normalization. Defaults to 'linear'. + mul_factor (float, optional): Scaling factor. Defaults to 1.0. + + Returns: + np.ndarray: `y` after normalization. """ if norm_type == 'linear': @@ -199,12 +245,13 @@ def normalize(y: array, norm_type: str='linear', return y -def save(y: array, sr: int, file: str) -> None: - """Save audio file to disk. 
- This function saves audio to disk using scipy.io.wavfile, with additional step - to convert input waveform to int16 unless it already is int16 - Notes: - It only support raw wav format. +def save(y: np.ndarray, sr: int, file: os.PathLike) -> None: + """Save audio file to disk. This function saves audio to disk using scipy.io.wavfile, with additional step to convert input waveform to int16. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + sr (int): Sample rate. + file (os.PathLike): Path of auido file to save. """ if not file.endswith('.wav'): raise ParameterError( @@ -226,7 +273,7 @@ def save(y: array, sr: int, file: str) -> None: def load( - file: str, + file: os.PathLike, sr: Optional[int]=None, mono: bool=True, merge_type: str='average', # ch0,ch1,random,average @@ -236,11 +283,24 @@ def load( offset: float=0.0, duration: Optional[int]=None, dtype: str='float32', - resample_mode: str='kaiser_fast') -> Tuple[array, int]: - """Load audio file from disk. - This function loads audio from disk using using audio beackend. - Parameters: - Notes: + resample_mode: str='kaiser_fast') -> Tuple[np.ndarray, int]: + """Load audio file from disk. This function loads audio from disk using using audio beackend. + + Args: + file (os.PathLike): Path of auido file to load. + sr (Optional[int], optional): Sample rate of loaded waveform. Defaults to None. + mono (bool, optional): Return waveform with mono channel. Defaults to True. + merge_type (str, optional): Merge type of multi-channels waveform. Defaults to 'average'. + normal (bool, optional): Waveform normalization. Defaults to True. + norm_type (str, optional): Type of normalization. Defaults to 'linear'. + norm_mul_factor (float, optional): Scaling factor. Defaults to 1.0. + offset (float, optional): Offset to the start of waveform. Defaults to 0.0. + duration (Optional[int], optional): Duration of waveform to read. Defaults to None. + dtype (str, optional): Data type of waveform. Defaults to 'float32'. + resample_mode (str, optional): The resampling filter to use. Defaults to 'kaiser_fast'. + + Returns: + Tuple[np.ndarray, int]: Waveform in ndarray and its samplerate. """ y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration) diff --git a/paddleaudio/paddleaudio/compliance/kaldi.py b/paddleaudio/paddleaudio/compliance/kaldi.py index 8cb9b666053..538be019619 100644 --- a/paddleaudio/paddleaudio/compliance/kaldi.py +++ b/paddleaudio/paddleaudio/compliance/kaldi.py @@ -220,7 +220,7 @@ def spectrogram(waveform: Tensor, """Compute and return a spectrogram from a waveform. The output is identical to Kaldi's. Args: - waveform (Tensor): A waveform tensor with shape [C, T]. + waveform (Tensor): A waveform tensor with shape `(C, T)`. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. channel (int, optional): Select the channel of waveform. Defaults to -1. dither (float, optional): Dithering constant . Defaults to 0.0. @@ -239,7 +239,7 @@ def spectrogram(waveform: Tensor, window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. Returns: - Tensor: A spectrogram tensor with shape (m, padded_window_size // 2 + 1) where m is the number of frames + Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames depends on frame_length and frame_shift. """ dtype = waveform.dtype @@ -422,7 +422,7 @@ def fbank(waveform: Tensor, """Compute and return filter banks from a waveform. The output is identical to Kaldi's. 
Args: - waveform (Tensor): A waveform tensor with shape [C, T]. + waveform (Tensor): A waveform tensor with shape `(C, T)`. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. channel (int, optional): Select the channel of waveform. Defaults to -1. dither (float, optional): Dithering constant . Defaults to 0.0. @@ -451,7 +451,7 @@ def fbank(waveform: Tensor, window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. Returns: - Tensor: A filter banks tensor with shape (m, n_mels). + Tensor: A filter banks tensor with shape `(m, n_mels)`. """ dtype = waveform.dtype @@ -542,7 +542,7 @@ def mfcc(waveform: Tensor, identical to Kaldi's. Args: - waveform (Tensor): A waveform tensor with shape [C, T]. + waveform (Tensor): A waveform tensor with shape `(C, T)`. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0. channel (int, optional): Select the channel of waveform. Defaults to -1. @@ -571,7 +571,7 @@ def mfcc(waveform: Tensor, window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. Returns: - Tensor: A mel frequency cepstral coefficients tensor with shape (m, n_mfcc). + Tensor: A mel frequency cepstral coefficients tensor with shape `(m, n_mfcc)`. """ assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % ( n_mfcc, n_mels) diff --git a/paddleaudio/paddleaudio/compliance/librosa.py b/paddleaudio/paddleaudio/compliance/librosa.py index 167795c3701..740584ca5a2 100644 --- a/paddleaudio/paddleaudio/compliance/librosa.py +++ b/paddleaudio/paddleaudio/compliance/librosa.py @@ -19,7 +19,6 @@ import numpy as np import scipy -from numpy import ndarray as array from numpy.lib.stride_tricks import as_strided from scipy import signal @@ -32,7 +31,6 @@ 'mfcc', 'hz_to_mel', 'mel_to_hz', - 'split_frames', 'mel_frequencies', 'power_to_db', 'compute_fbank_matrix', @@ -49,7 +47,8 @@ ] -def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array: +def _pad_center(data: np.ndarray, size: int, axis: int=-1, + **kwargs) -> np.ndarray: """Pad an array to a target length along a target axis. This differs from `np.pad` by centering the data prior to padding, @@ -69,8 +68,10 @@ def pad_center(data: array, size: int, axis: int=-1, **kwargs) -> array: return np.pad(data, lengths, **kwargs) -def split_frames(x: array, frame_length: int, hop_length: int, - axis: int=-1) -> array: +def _split_frames(x: np.ndarray, + frame_length: int, + hop_length: int, + axis: int=-1) -> np.ndarray: """Slice a data array into (overlapping) frames. This function is aligned with librosa.frame @@ -142,11 +143,16 @@ def _check_audio(y, mono=True) -> bool: return True -def hz_to_mel(frequencies: Union[float, List[float], array], - htk: bool=False) -> array: - """Convert Hz to Mels +def hz_to_mel(frequencies: Union[float, List[float], np.ndarray], + htk: bool=False) -> np.ndarray: + """Convert Hz to Mels. - This function is aligned with librosa. + Args: + frequencies (Union[float, List[float], np.ndarray]): Frequencies in Hz. + htk (bool, optional): Use htk scaling. Defaults to False. + + Returns: + np.ndarray: Frequency in mels. 
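+
+    Example (illustrative): with the default slaney scaling, frequencies
+    below 1000 Hz map linearly, so `hz_to_mel(440.0)` is `440 / (200 / 3) = 6.6`.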
""" freq = np.asanyarray(frequencies) @@ -177,10 +183,16 @@ def hz_to_mel(frequencies: Union[float, List[float], array], return mels -def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array: +def mel_to_hz(mels: Union[float, List[float], np.ndarray], + htk: int=False) -> np.ndarray: """Convert mel bin numbers to frequencies. - This function is aligned with librosa. + Args: + mels (Union[float, List[float], np.ndarray]): Frequency in mels. + htk (bool, optional): Use htk scaling. Defaults to False. + + Returns: + np.ndarray: Frequencies in Hz. """ mel_array = np.asanyarray(mels) @@ -212,10 +224,17 @@ def mel_to_hz(mels: Union[float, List[float], array], htk: int=False) -> array: def mel_frequencies(n_mels: int=128, fmin: float=0.0, fmax: float=11025.0, - htk: bool=False) -> array: - """Compute mel frequencies + htk: bool=False) -> np.ndarray: + """Compute mel frequencies. + + Args: + n_mels (int, optional): Number of mel bins. Defaults to 128. + fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0. + fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0. + htk (bool, optional): Use htk scaling. Defaults to False. - This function is aligned with librosa. + Returns: + np.ndarray: Vector of n_mels frequencies in Hz with shape `(n_mels,)`. """ # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = hz_to_mel(fmin, htk=htk) @@ -226,10 +245,15 @@ def mel_frequencies(n_mels: int=128, return mel_to_hz(mels, htk=htk) -def fft_frequencies(sr: int, n_fft: int) -> array: +def fft_frequencies(sr: int, n_fft: int) -> np.ndarray: """Compute fourier frequencies. - This function is aligned with librosa. + Args: + sr (int): Sample rate. + n_fft (int): FFT size. + + Returns: + np.ndarray: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`. """ return np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True) @@ -241,10 +265,22 @@ def compute_fbank_matrix(sr: int, fmax: Optional[float]=None, htk: bool=False, norm: str="slaney", - dtype: type=np.float32): + dtype: type=np.float32) -> np.ndarray: """Compute fbank matrix. - This funciton is aligned with librosa. + Args: + sr (int): Sample rate. + n_fft (int): FFT size. + n_mels (int, optional): Number of mel bins. Defaults to 128. + fmin (float, optional): Minimum frequency in Hz. Defaults to 0.0. + fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use htk scaling. Defaults to False. + norm (str, optional): Type of normalization. Defaults to "slaney". + dtype (type, optional): Data type. Defaults to np.float32. + + + Returns: + np.ndarray: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`. """ if norm != "slaney": raise ParameterError('norm must set to slaney') @@ -289,17 +325,28 @@ def compute_fbank_matrix(sr: int, return weights -def stft(x: array, +def stft(x: np.ndarray, n_fft: int=2048, hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str="hann", center: bool=True, dtype: type=np.complex64, - pad_mode: str="reflect") -> array: + pad_mode: str="reflect") -> np.ndarray: """Short-time Fourier transform (STFT). - This function is aligned with librosa. + Args: + x (np.ndarray): Input waveform in one dimension. + n_fft (int, optional): FFT size. Defaults to 2048. + hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None. + win_length (Optional[int], optional): The size of window. Defaults to None. + window (str, optional): A string of window specification. 
Defaults to "hann". + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + dtype (type, optional): Data type of STFT results. Defaults to np.complex64. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". + + Returns: + np.ndarray: The complex STFT output with shape `(n_fft//2 + 1, num_frames)`. """ _check_audio(x) @@ -314,7 +361,7 @@ def stft(x: array, fft_window = signal.get_window(window, win_length, fftbins=True) # Pad the window out to n_fft size - fft_window = pad_center(fft_window, n_fft) + fft_window = _pad_center(fft_window, n_fft) # Reshape so that the window can be broadcast fft_window = fft_window.reshape((-1, 1)) @@ -333,7 +380,7 @@ def stft(x: array, ) # Window the time series. - x_frames = split_frames(x, frame_length=n_fft, hop_length=hop_length) + x_frames = _split_frames(x, frame_length=n_fft, hop_length=hop_length) # Pre-allocate the STFT matrix stft_matrix = np.empty( (int(1 + n_fft // 2), x_frames.shape[1]), dtype=dtype, order="F") @@ -352,16 +399,20 @@ def stft(x: array, return stft_matrix -def power_to_db(spect: array, +def power_to_db(spect: np.ndarray, ref: float=1.0, amin: float=1e-10, - top_db: Optional[float]=80.0) -> array: - """Convert a power spectrogram (amplitude squared) to decibel (dB) units + top_db: Optional[float]=80.0) -> np.ndarray: + """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way. - This computes the scaling ``10 * log10(spect / ref)`` in a numerically - stable way. + Args: + spect (np.ndarray): STFT power spectrogram of an input waveform. + ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): Minimum threshold. Defaults to 1e-10. + top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to 80.0. - This function is aligned with librosa. + Returns: + np.ndarray: Power spectrogram in db scale. """ spect = np.asarray(spect) @@ -394,49 +445,27 @@ def power_to_db(spect: array, return log_spec -def mfcc(x, +def mfcc(x: np.ndarray, sr: int=16000, - spect: Optional[array]=None, + spect: Optional[np.ndarray]=None, n_mfcc: int=20, dct_type: int=2, norm: str="ortho", lifter: int=0, - **kwargs) -> array: + **kwargs) -> np.ndarray: """Mel-frequency cepstral coefficients (MFCCs) - This function is NOT strictly aligned with librosa. The following example shows how to get the - same result with librosa: - - # mfcc: - kwargs = { - 'window_size':512, - 'hop_length':320, - 'mel_bins':64, - 'fmin':50, - 'to_db':False} - a = mfcc(x, - spect=None, - n_mfcc=20, - dct_type=2, - norm='ortho', - lifter=0, - **kwargs) - - # librosa mfcc: - spect = librosa.feature.melspectrogram(y=x,sr=16000,n_fft=512, - win_length=512, - hop_length=320, - n_mels=64, fmin=50) - b = librosa.feature.mfcc(y=x, - sr=16000, - S=spect, - n_mfcc=20, - dct_type=2, - norm='ortho', - lifter=0) - - assert np.mean( (a-b)**2) < 1e-8 + Args: + x (np.ndarray): Input waveform in one dimension. + sr (int, optional): Sample rate. Defaults to 16000. + spect (Optional[np.ndarray], optional): Input log-power Mel spectrogram. Defaults to None. + n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 20. + dct_type (int, optional): Discrete cosine transform (DCT) type. 
Defaults to 2. + norm (str, optional): Type of normalization. Defaults to "ortho". + lifter (int, optional): Cepstral filtering. Defaults to 0. + Returns: + np.ndarray: Mel frequency cepstral coefficients array with shape `(n_mfcc, num_frames)`. """ if spect is None: spect = melspectrogram(x, sr=sr, **kwargs) @@ -454,12 +483,12 @@ def mfcc(x, f"MFCC lifter={lifter} must be a non-negative number") -def melspectrogram(x: array, +def melspectrogram(x: np.ndarray, sr: int=16000, window_size: int=512, hop_length: int=320, n_mels: int=64, - fmin: int=50, + fmin: float=50.0, fmax: Optional[float]=None, window: str='hann', center: bool=True, @@ -468,27 +497,28 @@ def melspectrogram(x: array, to_db: bool=True, ref: float=1.0, amin: float=1e-10, - top_db: Optional[float]=None) -> array: + top_db: Optional[float]=None) -> np.ndarray: """Compute mel-spectrogram. - Parameters: - x: numpy.ndarray - The input wavform is a numpy array [shape=(n,)] - - window_size: int, typically 512, 1024, 2048, etc. - The window size for framing, also used as n_fft for stft - + Args: + x (np.ndarray): Input waveform in one dimension. + sr (int, optional): Sample rate. Defaults to 16000. + window_size (int, optional): Size of FFT and window length. Defaults to 512. + hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320. + n_mels (int, optional): Number of mel bins. Defaults to 64. + fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0. + fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + window (str, optional): A string of window specification. Defaults to "hann". + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". + power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. + to_db (bool, optional): Enable db scale. Defaults to True. + ref (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): Minimum threshold. Defaults to 1e-10. + top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None. Returns: - The mel-spectrogram in power scale or db scale(default) - - - Notes: - 1. sr is default to 16000, which is commonly used in speech/speaker processing. - 2. when fmax is None, it is set to sr//2. - 3. this function will convert mel spectgrum to db scale by default. This is different - that of librosa. - + np.ndarray: The mel-spectrogram in power scale or db scale with shape `(n_mels, num_frames)`. """ _check_audio(x, mono=True) if len(x) <= 0: @@ -518,18 +548,28 @@ def melspectrogram(x: array, return mel_spect -def spectrogram(x: array, +def spectrogram(x: np.ndarray, sr: int=16000, window_size: int=512, hop_length: int=320, window: str='hann', center: bool=True, pad_mode: str='reflect', - power: float=2.0) -> array: - """Compute spectrogram from an input waveform. + power: float=2.0) -> np.ndarray: + """Compute spectrogram. + + Args: + x (np.ndarray): Input waveform in one dimension. + sr (int, optional): Sample rate. Defaults to 16000. + window_size (int, optional): Size of FFT and window length. Defaults to 512. + hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320. 
+ window (str, optional): A string of window specification. Defaults to "hann". + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. - This function is a wrapper for librosa.feature.stft, with addition step to - compute the magnitude of the complex spectrogram. + Returns: + np.ndarray: The STFT spectrogram in power scale with shape `(n_fft//2 + 1, num_frames)`. """ s = stft( @@ -544,18 +584,16 @@ return np.abs(s)**power -def mu_encode(x: array, mu: int=255, quantized: bool=True) -> array: - """Mu-law encoding. - - Compute the mu-law decoding given an input code. - When quantized is True, the result will be converted to - integer in range [0,mu-1]. Otherwise, the resulting signal - is in range [-1,1] - +def mu_encode(x: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray: + """Mu-law encoding. Encode waveform based on mu-law companding. When quantized is True, the result will be converted to integer in range `[0,mu-1]`. Otherwise, the resulting waveform is in range `[-1,1]`. - Reference: - https://en.wikipedia.org/wiki/%CE%9C-law_algorithm + Args: + x (np.ndarray): The input waveform to encode. + mu (int, optional): The encoding parameter. Defaults to 255. + quantized (bool, optional): If `True`, quantize the encoded values into `1 + mu` distinct integer values. Defaults to True. + Returns: + np.ndarray: The mu-law encoded waveform. """ mu = 255 y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) @@ -564,17 +602,16 @@ return y -def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array: - """Mu-law decoding. - - Compute the mu-law decoding given an input code. +def mu_decode(y: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray: + """Mu-law decoding. Compute the mu-law decoding given an input code. It assumes that the input `y` is in range `[0,mu-1]` when quantize is True and `[-1,1]` otherwise. - it assumes that the input y is in - range [0,mu-1] when quantize is True and [-1,1] otherwise - - Reference: - https://en.wikipedia.org/wiki/%CE%9C-law_algorithm + Args: + y (np.ndarray): The encoded waveform. + mu (int, optional): The encoding parameter. Defaults to 255. + quantized (bool, optional): If `True`, the input is assumed to be quantized to `1 + mu` distinct integer values. Defaults to True. + Returns: + np.ndarray: The mu-law decoded waveform. 
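For reference, the companding math documented in this hunk can be sketched standalone in NumPy. This is a sketch built from the formula visible above (`np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)`), not the packaged `paddleaudio` functions; the exact mapping of `[-1, 1]` onto `1 + mu` integer codes is an assumption based on the docstring.

```python
import numpy as np

def mu_encode_sketch(x: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
    # Compress the dynamic range of x in [-1, 1] (mu-law companding).
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    if quantized:
        # Assumed quantizer: map [-1, 1] onto the integer codes {0, ..., mu}.
        y = np.floor((y + 1) / 2 * mu + 0.5)
    return y

def mu_decode_sketch(y: np.ndarray, mu: int=255, quantized: bool=True) -> np.ndarray:
    if quantized:
        y = y * 2 / mu - 1  # back to [-1, 1]
    # Invert y = sign(x) * log1p(mu * |x|) / log1p(mu).
    return np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1)

x = np.linspace(-1, 1, 5)
# Round trip recovers the waveform up to quantization error.
np.testing.assert_allclose(mu_decode_sketch(mu_encode_sketch(x)), x, atol=1e-2)
```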
""" if mu < 1: raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...') @@ -586,7 +623,7 @@ def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array: return x -def randint(high: int) -> int: +def _randint(high: int) -> int: """Generate one random integer in range [0 high) This is a helper function for random data augmentaiton @@ -594,20 +631,18 @@ def randint(high: int) -> int: return int(np.random.randint(0, high=high)) -def rand() -> float: - """Generate one floating-point number in range [0 1) - - This is a helper function for random data augmentaiton - """ - return float(np.random.rand(1)) - - -def depth_augment(y: array, +def depth_augment(y: np.ndarray, choices: List=['int8', 'int16'], - probs: List[float]=[0.5, 0.5]) -> array: - """ Audio depth augmentation + probs: List[float]=[0.5, 0.5]) -> np.ndarray: + """ Audio depth augmentation. Do audio depth augmentation to simulate the distortion brought by quantization. + + Args: + y (np.ndarray): Input waveform array in 1D or 2D. + choices (List, optional): A list of data type to depth conversion. Defaults to ['int8', 'int16']. + probs (List[float], optional): Probabilities to depth conversion. Defaults to [0.5, 0.5]. - Do audio depth augmentation to simulate the distortion brought by quantization. + Returns: + np.ndarray: The augmented waveform. """ assert len(probs) == len( choices @@ -621,13 +656,18 @@ def depth_augment(y: array, return y2 -def adaptive_spect_augment(spect: array, tempo_axis: int=0, - level: float=0.1) -> array: - """Do adpative spectrogram augmentation +def adaptive_spect_augment(spect: np.ndarray, + tempo_axis: int=0, + level: float=0.1) -> np.ndarray: + """Do adpative spectrogram augmentation. The level of the augmentation is gowern by the paramter level, ranging from 0 to 1, with 0 represents no augmentation. - The level of the augmentation is gowern by the paramter level, - ranging from 0 to 1, with 0 represents no augmentation。 + Args: + spect (np.ndarray): Input spectrogram. + tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0. + level (float, optional): The level factor of masking. Defaults to 0.1. + Returns: + np.ndarray: The augmented spectrogram. """ assert spect.ndim == 2., 'only supports 2d tensor or numpy array' if tempo_axis == 0: @@ -643,32 +683,40 @@ def adaptive_spect_augment(spect: array, tempo_axis: int=0, if tempo_axis == 0: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[start:start + time_mask_width, :] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[:, start:start + freq_mask_width] = 0 else: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[:, start:start + time_mask_width] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[start:start + freq_mask_width, :] = 0 return spect -def spect_augment(spect: array, +def spect_augment(spect: np.ndarray, tempo_axis: int=0, max_time_mask: int=3, max_freq_mask: int=3, max_time_mask_width: int=30, - max_freq_mask_width: int=20) -> array: - """Do spectrogram augmentation in both time and freq axis + max_freq_mask_width: int=20) -> np.ndarray: + """Do spectrogram augmentation in both time and freq axis. - Reference: + Args: + spect (np.ndarray): Input spectrogram. + tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0. 
+ max_time_mask (int, optional): Maximum number of time masks. Defaults to 3. + max_freq_mask (int, optional): Maximum number of frequency masks. Defaults to 3. + max_time_mask_width (int, optional): Maximum width of each time mask. Defaults to 30. + max_freq_mask_width (int, optional): Maximum width of each frequency mask. Defaults to 20. + Returns: + np.ndarray: The augmented spectrogram. """ assert spect.ndim == 2., 'only supports 2d tensor or numpy array' if tempo_axis == 0: @@ -676,52 +724,64 @@ else: nf, nt = spect.shape - num_time_mask = randint(max_time_mask) - num_freq_mask = randint(max_freq_mask) + num_time_mask = _randint(max_time_mask) + num_freq_mask = _randint(max_freq_mask) - time_mask_width = randint(max_time_mask_width) - freq_mask_width = randint(max_freq_mask_width) + time_mask_width = _randint(max_time_mask_width) + freq_mask_width = _randint(max_freq_mask_width) if tempo_axis == 0: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[start:start + time_mask_width, :] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[:, start:start + freq_mask_width] = 0 else: for _ in range(num_time_mask): - start = randint(nt - time_mask_width) + start = _randint(nt - time_mask_width) spect[:, start:start + time_mask_width] = 0 for _ in range(num_freq_mask): - start = randint(nf - freq_mask_width) + start = _randint(nf - freq_mask_width) spect[start:start + freq_mask_width, :] = 0 return spect -def random_crop1d(y: array, crop_len: int) -> array: - """ Do random cropping on 1d input signal +def random_crop1d(y: np.ndarray, crop_len: int) -> np.ndarray: + """ Random cropping on an input waveform. - The input is a 1d signal, typically a sound waveform + Args: + y (np.ndarray): Input waveform array in 1D. + crop_len (int): Length of waveform to crop. + + Returns: + np.ndarray: The cropped waveform. """ if y.ndim != 1: 'only accept 1d tensor or numpy array' n = len(y) - idx = randint(n - crop_len) + idx = _randint(n - crop_len) return y[idx:idx + crop_len] -def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array: - """ Do random cropping for 2D array, typically a spectrogram. +def random_crop2d(s: np.ndarray, crop_len: int, + tempo_axis: int=0) -> np.ndarray: + """ Random cropping on a spectrogram. - The cropping is done in temporal direction on the time-freq input signal. + Args: + s (np.ndarray): Input spectrogram in 2D. + crop_len (int): Length of spectrogram to crop. + tempo_axis (int, optional): Indicate the tempo axis. Defaults to 0. + + Returns: + np.ndarray: The cropped spectrogram. """ if tempo_axis >= s.ndim: raise ParameterError('axis out of range') n = s.shape[tempo_axis] - idx = randint(high=n - crop_len) + idx = _randint(high=n - crop_len) sli = [slice(None) for i in range(s.ndim)] sli[tempo_axis] = slice(idx, idx + crop_len) out = s[tuple(sli)] diff --git a/paddleaudio/paddleaudio/features/layers.py b/paddleaudio/paddleaudio/features/layers.py index 4a2c1673a02..09037255ddf 100644 --- a/paddleaudio/paddleaudio/features/layers.py +++ b/paddleaudio/paddleaudio/features/layers.py @@ -17,6 +17,7 @@ import paddle import paddle.nn as nn +from paddle import Tensor from ..functional import compute_fbank_matrix from ..functional import create_dct @@ -32,42 +33,34 @@ class Spectrogram(nn.Layer): + """Compute spectrogram of given signals, typically audio waveforms. 
+ The spectrogram is defined as the complex norm of the short-time Fourier transformation. + + Args: + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + def __init__(self, n_fft: int=512, hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', - dtype: str=paddle.float32): - """Compute spectrogram of a given signal, typically an audio waveform. - The spectorgram is defined as the complex norm of the short-time - Fourier transformation. - Parameters: - n_fft (int): the number of frequency components of the discrete Fourier transform. - The default value is 2048, - hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4. - The default value is None. - win_length: the window length of the short time FFt. If None, it is set to same as n_fft. - The default value is None. - window (str): the name of the window function applied to the single before the Fourier transform. - The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', - 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. - The default value is 'hann' - center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. - If False, frame t begins at x[t * hop_length] - The default value is True - pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect' - and 'constant'. The default value is 'reflect'. - dtype (str): the data type of input and window. - Notes: - The Spectrogram transform relies on STFT transform to compute the spectrogram. - By default, the weights are not learnable. To fine-tune the Fourier coefficients, - set stop_gradient=False before training. - For more information, see STFT(). - """ + dtype: str='float32') -> None: super(Spectrogram, self).__init__() + assert power > 0, 'Power of spectrogram must be > 0.' + self.power = power + if win_length is None: win_length = n_fft @@ -83,19 +76,46 @@ pad_mode=pad_mode) self.register_buffer('fft_window', self.fft_window) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Spectrograms with shape `(N, n_fft//2 + 1, num_frames)`. 
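A quick usage sketch for the refactored layer, assuming `Spectrogram` is importable from `paddleaudio.features` as the file path in this hunk suggests:

```python
import paddle
from paddleaudio.features import Spectrogram

x = paddle.randn([2, 16000])  # a batch of two 1-second waveforms at 16 kHz

# The new `power` argument flows through to
# paddle.pow(paddle.abs(stft), self.power) in forward().
spec = Spectrogram(n_fft=512, hop_length=160, power=2.0)
y = spec(x)
print(y.shape)  # [2, 257, 101], i.e. (N, n_fft//2 + 1, num_frames)
```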
+ """ stft = self._stft(x) - spectrogram = paddle.square(paddle.abs(stft)) + spectrogram = paddle.pow(paddle.abs(stft), self.power) return spectrogram class MelSpectrogram(nn.Layer): + """Compute the melspectrogram of given signals, typically audio waveforms. It is computed by multiplying spectrogram with Mel filter bank matrix. + + Args: + sr (int, optional): Sample rate. Defaults to 22050. + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False. + norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'. + dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + def __init__(self, sr: int=22050, n_fft: int=512, hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', n_mels: int=64, @@ -103,38 +123,7 @@ def __init__(self, f_max: Optional[float]=None, htk: bool=False, norm: Union[str, float]='slaney', - dtype: str=paddle.float32): - """Compute the melspectrogram of a given signal, typically an audio waveform. - The melspectrogram is also known as filterbank or fbank feature in audio community. - It is computed by multiplying spectrogram with Mel filter bank matrix. - Parameters: - sr(int): the audio sample rate. - The default value is 22050. - n_fft(int): the number of frequency components of the discrete Fourier transform. - The default value is 2048, - hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4. - The default value is None. - win_length: the window length of the short time FFt. If None, it is set to same as n_fft. - The default value is None. - window(str): the name of the window function applied to the single before the Fourier transform. - The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', - 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. - The default value is 'hann' - center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. - If False, frame t begins at x[t * hop_length] - The default value is True - pad_mode(str): the mode to pad the signal if necessary. 
The supported modes are 'reflect' - and 'constant'. - The default value is 'reflect'. - n_mels(int): the mel bins. - f_min(float): the lower cut-off frequency, below which the filter response is zero. - f_max(float): the upper cut-off frequency, above which the filter response is zeros. - htk(bool): whether to use HTK formula in computing fbank matrix. - norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default. - You can specify norm=1.0/2.0 to use customized p-norm normalization. - dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical - accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix. - """ + dtype: str='float32') -> None: super(MelSpectrogram, self).__init__() self._spectrogram = Spectrogram( @@ -142,6 +131,7 @@ def __init__(self, hop_length=hop_length, win_length=win_length, window=window, + power=power, center=center, pad_mode=pad_mode, dtype=dtype) @@ -163,19 +153,49 @@ def __init__(self, dtype=dtype) # float64 for better numerical results self.register_buffer('fbank_matrix', self.fbank_matrix) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Mel spectrograms with shape `(N, n_mels, num_frames)`. + """ spect_feature = self._spectrogram(x) mel_feature = paddle.matmul(self.fbank_matrix, spect_feature) return mel_feature class LogMelSpectrogram(nn.Layer): + """Compute log-mel-spectrogram feature of given signals, typically audio waveforms. + + Args: + sr (int, optional): Sample rate. Defaults to 22050. + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False. + norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'. + ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): The minimum value of input magnitude. Defaults to 1e-10. + top_db (Optional[float], optional): The maximum db value of spectrogram. Defaults to None. 
+ dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + def __init__(self, sr: int=22050, n_fft: int=512, hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', n_mels: int=64, @@ -186,44 +206,7 @@ def __init__(self, ref_value: float=1.0, amin: float=1e-10, top_db: Optional[float]=None, - dtype: str=paddle.float32): - """Compute log-mel-spectrogram(also known as LogFBank) feature of a given signal, - typically an audio waveform. - Parameters: - sr (int): the audio sample rate. - The default value is 22050. - n_fft (int): the number of frequency components of the discrete Fourier transform. - The default value is 2048, - hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4. - The default value is None. - win_length: the window length of the short time FFt. If None, it is set to same as n_fft. - The default value is None. - window (str): the name of the window function applied to the single before the Fourier transform. - The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', - 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. - The default value is 'hann' - center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. - If False, frame t begins at x[t * hop_length] - The default value is True - pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect' - and 'constant'. - The default value is 'reflect'. - n_mels (int): the mel bins. - f_min (float): the lower cut-off frequency, below which the filter response is zero. - f_max (float): the upper cut-off frequency, above which the filter response is zeros. - htk (bool): whether to use HTK formula in computing fbank matrix. - norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default. - You can specify norm=1.0/2.0 to use customized p-norm normalization. - ref_value (float): the reference value. If smaller than 1.0, the db level - amin (float): the minimum value of input magnitude, below which the input of the signal will be pulled up accordingly. - Otherwise, the db level is pushed down. - magnitude is clipped(to amin). For numerical stability, set amin to a larger value, - e.g., 1e-3. - top_db (float): the maximum db value of resulting spectrum, above which the - spectrum is clipped(to top_db). - dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical - accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix. - """ + dtype: str='float32') -> None: super(LogMelSpectrogram, self).__init__() self._melspectrogram = MelSpectrogram( @@ -232,6 +215,7 @@ def __init__(self, hop_length=hop_length, win_length=win_length, window=window, + power=power, center=center, pad_mode=pad_mode, n_mels=n_mels, @@ -245,8 +229,14 @@ def __init__(self, self.amin = amin self.top_db = top_db - def forward(self, x): - # import ipdb; ipdb.set_trace() + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Log mel spectrograms with shape `(N, n_mels, num_frames)`. 
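As above, a usage sketch under the same assumption about the `paddleaudio.features` export:

```python
import paddle
from paddleaudio.features import LogMelSpectrogram

x = paddle.randn([4, 32000])  # four 2-second clips at 16 kHz
feature = LogMelSpectrogram(
    sr=16000, n_fft=512, hop_length=160, n_mels=64, top_db=80.0)
logmel = feature(x)  # mel spectrogram converted to dB via power_to_db
print(logmel.shape)  # [4, 64, 201], i.e. (N, n_mels, num_frames)
```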
+ """ mel_feature = self._melspectrogram(x) log_mel_feature = power_to_db( mel_feature, @@ -257,6 +247,29 @@ def forward(self, x): class MFCC(nn.Layer): + """Compute mel frequency cepstral coefficients(MFCCs) feature of given waveforms. + + Args: + sr (int, optional): Sample rate. Defaults to 22050. + n_mfcc (int, optional): [description]. Defaults to 40. + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False. + norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'. + ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): The minimum value of input magnitude. Defaults to 1e-10. + top_db (Optional[float], optional): The maximum db value of spectrogram. Defaults to None. + dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + def __init__(self, sr: int=22050, n_mfcc: int=40, @@ -264,6 +277,7 @@ def __init__(self, hop_length: Optional[int]=None, win_length: Optional[int]=None, window: str='hann', + power: float=2.0, center: bool=True, pad_mode: str='reflect', n_mels: int=64, @@ -274,45 +288,7 @@ def __init__(self, ref_value: float=1.0, amin: float=1e-10, top_db: Optional[float]=None, - dtype: str=paddle.float32): - """Compute mel frequency cepstral coefficients(MFCCs) feature of given waveforms. - - Parameters: - sr(int): the audio sample rate. - The default value is 22050. - n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 40. - n_fft (int): the number of frequency components of the discrete Fourier transform. - The default value is 2048, - hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4. - The default value is None. - win_length: the window length of the short time FFt. If None, it is set to same as n_fft. - The default value is None. - window (str): the name of the window function applied to the single before the Fourier transform. - The folllowing window names are supported: 'hamming','hann','kaiser','gaussian', - 'exponential','triang','bohman','blackman','cosine','tukey','taylor'. 
- The default value is 'hann' - center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length]. - If False, frame t begins at x[t * hop_length] - The default value is True - pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect' - and 'constant'. - The default value is 'reflect'. - n_mels (int): the mel bins. - f_min (float): the lower cut-off frequency, below which the filter response is zero. - f_max (float): the upper cut-off frequency, above which the filter response is zeros. - htk (bool): whether to use HTK formula in computing fbank matrix. - norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default. - You can specify norm=1.0/2.0 to use customized p-norm normalization. - ref_value (float): the reference value. If smaller than 1.0, the db level - amin (float): the minimum value of input magnitude, below which the input of the signal will be pulled up accordingly. - Otherwise, the db level is pushed down. - magnitude is clipped(to amin). For numerical stability, set amin to a larger value, - e.g., 1e-3. - top_db (float): the maximum db value of resulting spectrum, above which the - spectrum is clipped(to top_db). - dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical - accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix. - """ + dtype: str='float32') -> None: super(MFCC, self).__init__() assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % ( n_mfcc, n_mels) @@ -322,6 +298,7 @@ hop_length=hop_length, win_length=win_length, window=window, + power=power, center=center, pad_mode=pad_mode, n_mels=n_mels, @@ -336,7 +313,14 @@ self.dct_matrix = create_dct(n_mfcc=n_mfcc, n_mels=n_mels, dtype=dtype) self.register_buffer('dct_matrix', self.dct_matrix) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Mel frequency cepstral coefficients with shape `(N, n_mfcc, num_frames)`. + """ log_mel_feature = self._log_melspectrogram(x) mfcc = paddle.matmul( log_mel_feature.transpose((0, 2, 1)), self.dct_matrix).transpose( diff --git a/paddleaudio/paddleaudio/functional/functional.py b/paddleaudio/paddleaudio/functional/functional.py index c5ab30453e6..19c63a9aef2 100644 --- a/paddleaudio/paddleaudio/functional/functional.py +++ b/paddleaudio/paddleaudio/functional/functional.py @@ -17,6 +17,7 @@ from typing import Union import paddle +from paddle import Tensor __all__ = [ 'hz_to_mel', @@ -29,19 +30,20 @@ ] -def hz_to_mel(freq: Union[paddle.Tensor, float], - htk: bool=False) -> Union[paddle.Tensor, float]: +def hz_to_mel(freq: Union[Tensor, float], + htk: bool=False) -> Union[Tensor, float]: """Convert Hz to Mels. - Parameters: - freq: the input tensor of arbitrary shape, or a single floating point number. - htk: use HTK formula to do the conversion. - The default value is False. + + Args: + freq (Union[Tensor, float]): The input tensor with arbitrary shape. + htk (bool, optional): Use htk scaling. Defaults to False. + Returns: - The frequencies represented in Mel-scale. + Union[Tensor, float]: Frequency in mels. 
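The Slaney-style conversion implemented here is linear below 1 kHz and logarithmic above. A pure-Python sketch of the two branches, using the same constants as the surrounding code (200/3 Hz per mel, a 1 kHz break point, and a log step of log(6.4)/27); the helper names are for illustration only:

```python
import math

def hz_to_mel_slaney(freq: float) -> float:
    f_sp = 200.0 / 3          # Hz per mel in the linear region
    min_log_hz = 1000.0       # beginning of the log region (Hz)
    min_log_mel = min_log_hz / f_sp
    logstep = math.log(6.4) / 27.0
    if freq >= min_log_hz:
        return min_log_mel + math.log(freq / min_log_hz) / logstep
    return freq / f_sp

def mel_to_hz_slaney(mel: float) -> float:
    f_sp = 200.0 / 3
    min_log_hz = 1000.0
    min_log_mel = min_log_hz / f_sp
    logstep = math.log(6.4) / 27.0
    if mel >= min_log_mel:
        return min_log_hz * math.exp(logstep * (mel - min_log_mel))
    return f_sp * mel

# Round trip is exact in both regions.
for f in (440.0, 4000.0):
    assert abs(mel_to_hz_slaney(hz_to_mel_slaney(f)) - f) < 1e-6
```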
""" if htk: - if isinstance(freq, paddle.Tensor): + if isinstance(freq, Tensor): return 2595.0 * paddle.log10(1.0 + freq / 700.0) else: return 2595.0 * math.log10(1.0 + freq / 700.0) @@ -58,7 +60,7 @@ def hz_to_mel(freq: Union[paddle.Tensor, float], min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) logstep = math.log(6.4) / 27.0 # step size for log region - if isinstance(freq, paddle.Tensor): + if isinstance(freq, Tensor): target = min_log_mel + paddle.log( freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10 mask = (freq > min_log_hz).astype(freq.dtype) @@ -71,14 +73,16 @@ def hz_to_mel(freq: Union[paddle.Tensor, float], return mels -def mel_to_hz(mel: Union[float, paddle.Tensor], - htk: bool=False) -> Union[float, paddle.Tensor]: +def mel_to_hz(mel: Union[float, Tensor], + htk: bool=False) -> Union[float, Tensor]: """Convert mel bin numbers to frequencies. - Parameters: - mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number. - htk: use HTK formula to do the conversion. + + Args: + mel (Union[float, Tensor]): The mel frequency represented as a tensor with arbitrary shape. + htk (bool, optional): Use htk scaling. Defaults to False. + Returns: - The frequencies represented in hz. + Union[float, Tensor]: Frequencies in Hz. """ if htk: return 700.0 * (10.0**(mel / 2595.0) - 1.0) @@ -90,7 +94,7 @@ def mel_to_hz(mel: Union[float, paddle.Tensor], min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) logstep = math.log(6.4) / 27.0 # step size for log region - if isinstance(mel, paddle.Tensor): + if isinstance(mel, Tensor): target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel)) mask = (mel > min_log_mel).astype(mel.dtype) freqs = target * mask + freqs * ( @@ -106,16 +110,18 @@ def mel_frequencies(n_mels: int=64, f_min: float=0.0, f_max: float=11025.0, htk: bool=False, - dtype: str=paddle.float32): + dtype: str='float32') -> Tensor: """Compute mel frequencies. - Parameters: - n_mels(int): number of Mel bins. - f_min(float): the lower cut-off frequency, below which the filter response is zero. - f_max(float): the upper cut-off frequency, above which the filter response is zero. - htk(bool): whether to use htk formula. - dtype(str): the datatype of the return frequencies. + + Args: + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0. + fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0. + htk (bool, optional): Use htk scaling. Defaults to False. + dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'. + Returns: - The frequencies represented in Mel-scale + Tensor: Tensor of n_mels frequencies in Hz with shape `(n_mels,)`. """ # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = hz_to_mel(f_min, htk=htk) @@ -125,14 +131,16 @@ def mel_frequencies(n_mels: int=64, return freqs -def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32): +def fft_frequencies(sr: int, n_fft: int, dtype: str='float32') -> Tensor: """Compute fourier frequencies. - Parameters: - sr(int): the audio sample rate. - n_fft(float): the number of fft bins. - dtype(str): the datatype of the return frequencies. + + Args: + sr (int): Sample rate. + n_fft (int): Number of fft bins. + dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'. + Returns: - The frequencies represented in hz. 
+ Tensor: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`. """ return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype) @@ -144,23 +152,21 @@ def compute_fbank_matrix(sr: int, f_max: Optional[float]=None, htk: bool=False, norm: Union[str, float]='slaney', - dtype: str=paddle.float32): + dtype: str='float32') -> Tensor: """Compute fbank matrix. - Parameters: - sr(int): the audio sample rate. - n_fft(int): the number of fft bins. - n_mels(int): the number of Mel bins. - f_min(float): the lower cut-off frequency, below which the filter response is zero. - f_max(float): the upper cut-off frequency, above which the filter response is zero. - htk: whether to use htk formula. - return_complex(bool): whether to return complex matrix. If True, the matrix will - be complex type. Otherwise, the real and image part will be stored in the last - axis of returned tensor. - dtype(str): the datatype of the returned fbank matrix. + + Args: + sr (int): Sample rate. + n_fft (int): Number of fft bins. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use htk scaling. Defaults to False. + norm (Union[str, float], optional): Type of normalization. Defaults to 'slaney'. + dtype (str, optional): The data type of the return matrix. Defaults to 'float32'. + Returns: - The fbank matrix of shape (n_mels, int(1+n_fft//2)). - Shape: - output: (n_mels, int(1+n_fft//2)) + Tensor: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`. """ if f_max is None: @@ -199,27 +205,20 @@ def compute_fbank_matrix(sr: int, return weights -def power_to_db(magnitude: paddle.Tensor, +def power_to_db(spect: Tensor, ref_value: float=1.0, amin: float=1e-10, - top_db: Optional[float]=None) -> paddle.Tensor: - """Convert a power spectrogram (amplitude squared) to decibel (dB) units. - The function computes the scaling ``10 * log10(x / ref)`` in a numerically - stable way. - Parameters: - magnitude(Tensor): the input magnitude tensor of any shape. - ref_value(float): the reference value. If smaller than 1.0, the db level - of the signal will be pulled up accordingly. Otherwise, the db level - is pushed down. - amin(float): the minimum value of input magnitude, below which the input - magnitude is clipped(to amin). - top_db(float): the maximum db value of resulting spectrum, above which the - spectrum is clipped(to top_db). + top_db: Optional[float]=None) -> Tensor: + """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way. + + Args: + spect (Tensor): STFT power spectrogram. + ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): Minimum threshold. Defaults to 1e-10. + top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None. + Returns: - The spectrogram in log-scale. - shape: - input: any shape - output: same as input + Tensor: Power spectrogram in db scale. 
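A small numeric check of the scaling described above, assuming `power_to_db` is re-exported from `paddleaudio.functional` as the `__all__` list in this file suggests:

```python
import paddle
from paddleaudio.functional import power_to_db

power = paddle.to_tensor([1.0, 10.0, 100.0])
# 10 * log10(x / ref_value) with the default ref_value=1.0
# gives [0., 10., 20.] dB.
print(power_to_db(power))
```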
""" if amin <= 0: raise Exception("amin must be strictly positive") @@ -227,8 +226,8 @@ def power_to_db(magnitude: paddle.Tensor, if ref_value <= 0: raise Exception("ref_value must be strictly positive") - ones = paddle.ones_like(magnitude) - log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude)) + ones = paddle.ones_like(spect) + log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, spect)) log_spec -= 10.0 * math.log10(max(ref_value, amin)) if top_db is not None: @@ -242,15 +241,17 @@ def power_to_db(magnitude: paddle.Tensor, def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]='ortho', - dtype: Optional[str]=paddle.float32) -> paddle.Tensor: + dtype: str='float32') -> Tensor: """Create a discrete cosine transform(DCT) matrix. - Parameters: + Args: n_mfcc (int): Number of mel frequency cepstral coefficients. n_mels (int): Number of mel filterbanks. - norm (str, optional): Normalizaiton type. Defaults to 'ortho'. + norm (Optional[str], optional): Normalizaiton type. Defaults to 'ortho'. + dtype (str, optional): The data type of the return matrix. Defaults to 'float32'. + Returns: - Tensor: The DCT matrix with shape (n_mels, n_mfcc). + Tensor: The DCT matrix with shape `(n_mels, n_mfcc)`. """ n = paddle.arange(n_mels, dtype=dtype) k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1) diff --git a/paddleaudio/paddleaudio/functional/window.py b/paddleaudio/paddleaudio/functional/window.py index f321b38efa3..c99d50462e3 100644 --- a/paddleaudio/paddleaudio/functional/window.py +++ b/paddleaudio/paddleaudio/functional/window.py @@ -20,24 +20,11 @@ __all__ = [ 'get_window', - - # windows - 'taylor', - 'hamming', - 'hann', - 'tukey', - 'kaiser', - 'gaussian', - 'exponential', - 'triang', - 'bohman', - 'blackman', - 'cosine', ] -def _cat(a: List[Tensor], data_type: str) -> Tensor: - l = [paddle.to_tensor(_a, data_type) for _a in a] +def _cat(x: List[Tensor], data_type: str) -> Tensor: + l = [paddle.to_tensor(_, data_type) for _ in x] return paddle.concat(l) @@ -48,7 +35,7 @@ def _acosh(x: Union[Tensor, float]) -> Tensor: def _extend(M: int, sym: bool) -> bool: - """Extend window by 1 sample if needed for DFT-even symmetry""" + """Extend window by 1 sample if needed for DFT-even symmetry. """ if not sym: return M + 1, True else: @@ -56,7 +43,7 @@ def _extend(M: int, sym: bool) -> bool: def _len_guards(M: int) -> bool: - """Handle small or incorrect window lengths""" + """Handle small or incorrect window lengths. """ if int(M) != M or M < 0: raise ValueError('Window length M must be a non-negative integer') @@ -64,15 +51,15 @@ def _len_guards(M: int) -> bool: def _truncate(w: Tensor, needed: bool) -> Tensor: - """Truncate window by 1 sample if needed for DFT-even symmetry""" + """Truncate window by 1 sample if needed for DFT-even symmetry. """ if needed: return w[:-1] else: return w -def general_gaussian(M: int, p, sig, sym: bool=True, - dtype: str='float64') -> Tensor: +def _general_gaussian(M: int, p, sig, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a window with a generalized Gaussian shape. This function is consistent with scipy.signal.windows.general_gaussian(). """ @@ -86,8 +73,8 @@ def general_gaussian(M: int, p, sig, sym: bool=True, return _truncate(w, needs_trunc) -def general_cosine(M: int, a: float, sym: bool=True, - dtype: str='float64') -> Tensor: +def _general_cosine(M: int, a: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a generic weighted sum of cosine terms window. 
This function is consistent with scipy.signal.windows.general_cosine(). """ @@ -101,31 +88,23 @@ def general_cosine(M: int, a: float, sym: bool=True, return _truncate(w, needs_trunc) -def general_hamming(M: int, alpha: float, sym: bool=True, - dtype: str='float64') -> Tensor: +def _general_hamming(M: int, alpha: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a generalized Hamming window. This function is consistent with scipy.signal.windows.general_hamming() """ - return general_cosine(M, [alpha, 1. - alpha], sym, dtype=dtype) + return _general_cosine(M, [alpha, 1. - alpha], sym, dtype=dtype) -def taylor(M: int, - nbar=4, - sll=30, - norm=True, - sym: bool=True, - dtype: str='float64') -> Tensor: +def _taylor(M: int, + nbar=4, + sll=30, + norm=True, + sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a Taylor window. The Taylor window taper function approximates the Dolph-Chebyshev window's constant sidelobe level for a parameterized number of near-in sidelobes. - Parameters: - M(int): window size - nbar, sil, norm: the window-specific parameter. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -171,46 +150,25 @@ def W(n): return _truncate(w, needs_trunc) -def hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Hamming window. The Hamming window is a taper formed by using a raised cosine with non-zero endpoints, optimized to minimize the nearest side lobe. - Parameters: - M(int): window size - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ - return general_hamming(M, 0.54, sym, dtype=dtype) + return _general_hamming(M, 0.54, sym, dtype=dtype) -def hann(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _hann(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Hann window. The Hann window is a taper formed by using a raised cosine or sine-squared with ends that touch zero. - Parameters: - M(int): window size - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ - return general_hamming(M, 0.5, sym, dtype=dtype) + return _general_hamming(M, 0.5, sym, dtype=dtype) -def tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor: +def _tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Tukey window. The Tukey window is also known as a tapered cosine window. - Parameters: - M(int): window size - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -237,32 +195,18 @@ def tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -def kaiser(M: int, beta: float, sym: bool=True, dtype: str='float64') -> Tensor: +def _kaiser(M: int, beta: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a Kaiser window. The Kaiser window is a taper formed by using a Bessel function. - Parameters: - M(int): window size. - beta(float): the window-specific parameter. 
- sym(bool):whether to return symmetric window. - The default value is True - Returns: - Tensor: the window tensor """ raise NotImplementedError() -def gaussian(M: int, std: float, sym: bool=True, - dtype: str='float64') -> Tensor: +def _gaussian(M: int, std: float, sym: bool=True, + dtype: str='float64') -> Tensor: """Compute a Gaussian window. The Gaussian widows has a Gaussian shape defined by the standard deviation(std). - Parameters: - M(int): window size. - std(float): the window-specific parameter. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -275,21 +219,12 @@ def gaussian(M: int, std: float, sym: bool=True, return _truncate(w, needs_trunc) -def exponential(M: int, - center=None, - tau=1., - sym: bool=True, - dtype: str='float64') -> Tensor: - """Compute an exponential (or Poisson) window. - Parameters: - M(int): window size. - tau(float): the window-specific parameter. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor - """ +def _exponential(M: int, + center=None, + tau=1., + sym: bool=True, + dtype: str='float64') -> Tensor: + """Compute an exponential (or Poisson) window. """ if sym and center is not None: raise ValueError("If sym==True, center must be None.") if _len_guards(M): @@ -305,15 +240,8 @@ def exponential(M: int, return _truncate(w, needs_trunc) -def triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a triangular window. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -330,16 +258,9 @@ def triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -def bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Bohman window. The Bohman window is the autocorrelation of a cosine window. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -353,32 +274,18 @@ def bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -def blackman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _blackman(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a Blackman window. The Blackman window is a taper formed by using the first three terms of a summation of cosines. It was designed to have close to the minimal leakage possible. It is close to optimal, only slightly worse than a Kaiser window. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. 
- Returns: - Tensor: the window tensor """ - return general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) + return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) -def cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor: +def _cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor: """Compute a window with a simple cosine shape. - Parameters: - M(int): window size. - sym(bool):whether to return symmetric window. - The default value is True - dtype(str): the datatype of returned tensor. - Returns: - Tensor: the window tensor """ if _len_guards(M): return paddle.ones((M, ), dtype=dtype) @@ -388,19 +295,20 @@ def cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor: return _truncate(w, needs_trunc) -## factory function def get_window(window: Union[str, Tuple[str, float]], win_length: int, fftbins: bool=True, dtype: str='float64') -> Tensor: """Return a window of a given length and type. - Parameters: - window(str|(str,float)): the type of window to create. - win_length(int): the number of samples in the window. - fftbins(bool): If True, create a "periodic" window. Otherwise, - create a "symmetric" window, for use in filter design. + + Args: + window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. + win_length (int): Number of samples. + fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True. + dtype (str, optional): The data type of the return window. Defaults to 'float64'. + Returns: - The window represented as a tensor. + Tensor: The window represented as a tensor. """ sym = not fftbins @@ -420,7 +328,7 @@ def get_window(window: Union[str, Tuple[str, float]], str(type(window))) try: - winfunc = eval(winstr) + winfunc = eval('_' + winstr) except KeyError as e: raise ValueError("Unknown window type.") from e diff --git a/paddleaudio/paddleaudio/metric/dtw.py b/paddleaudio/paddleaudio/metric/dtw.py index d27f56e2832..c4dc7a283d9 100644 --- a/paddleaudio/paddleaudio/metric/dtw.py +++ b/paddleaudio/paddleaudio/metric/dtw.py @@ -20,9 +20,7 @@ def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float: - """dtw distance - - Dynamic Time Warping. + """Dynamic Time Warping. This function keeps a compact matrix, not the full warping paths matrix. Uses dynamic programming to compute: diff --git a/paddleaudio/setup.py b/paddleaudio/setup.py index 7623443a68b..930f86e41e5 100644 --- a/paddleaudio/setup.py +++ b/paddleaudio/setup.py @@ -11,19 +11,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import glob +import os + import setuptools +from setuptools.command.install import install +from setuptools.command.test import test # set the version here VERSION = '0.2.0' +# Inspired by the example at https://pytest.org/latest/goodpractises.html +class TestCommand(test): + def finalize_options(self): + test.finalize_options(self) + self.test_args = [] + self.test_suite = True + + def run(self): + self.run_benchmark() + super(TestCommand, self).run() + + def run_tests(self): + # Run nose ensuring that argv simulates running nosetests directly + import nose + nose.run_exit(argv=['nosetests', '-w', 'tests']) + + def run_benchmark(self): + for benchmark_item in glob.glob('tests/benchmark/*py'): + os.system(f'pytest {benchmark_item}') + + +class InstallCommand(install): + def run(self): + install.run(self) + + def write_version_py(filename='paddleaudio/__init__.py'): - import paddleaudio - if hasattr(paddleaudio, - "__version__") and paddleaudio.__version__ == VERSION: - return with open(filename, "a") as f: - f.write(f"\n__version__ = '{VERSION}'\n") + f.write(f"__version__ = '{VERSION}'") def remove_version_py(filename='paddleaudio/__init__.py'): @@ -35,6 +62,7 @@ def remove_version_py(filename='paddleaudio/__init__.py'): f.write(line) +remove_version_py() write_version_py() setuptools.setup( @@ -61,6 +89,16 @@ def remove_version_py(filename='paddleaudio/__init__.py'): 'colorlog', 'dtaidistance >= 2.3.6', 'mcd >= 0.4', - ], ) + ], + extras_require={ + 'test': [ + 'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1', + 'torchaudio==0.10.2', 'pytest-benchmark' + ], + }, + cmdclass={ + 'install': InstallCommand, + 'test': TestCommand, + }, ) remove_version_py() diff --git a/paddleaudio/tests/backends/__init__.py b/paddleaudio/tests/backends/__init__.py new file mode 100644 index 00000000000..97043fd7ba6 --- /dev/null +++ b/paddleaudio/tests/backends/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddleaudio/tests/backends/base.py b/paddleaudio/tests/backends/base.py new file mode 100644 index 00000000000..a67191887ff --- /dev/null +++ b/paddleaudio/tests/backends/base.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import unittest +import urllib.request + +mono_channel_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +multi_channels_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav' + + +class BackendTest(unittest.TestCase): + def setUp(self): + self.initWavInput() + + def initWavInput(self): + self.files = [] + for url in [mono_channel_wav, multi_channels_wav]: + if not os.path.isfile(os.path.basename(url)): + urllib.request.urlretrieve(url, os.path.basename(url)) + self.files.append(os.path.basename(url)) + + def initParmas(self): + raise NotImplementedError diff --git a/paddleaudio/tests/backends/soundfile/__init__.py b/paddleaudio/tests/backends/soundfile/__init__.py new file mode 100644 index 00000000000..97043fd7ba6 --- /dev/null +++ b/paddleaudio/tests/backends/soundfile/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddleaudio/tests/backends/soundfile/test_io.py b/paddleaudio/tests/backends/soundfile/test_io.py new file mode 100644 index 00000000000..0f7580a40d3 --- /dev/null +++ b/paddleaudio/tests/backends/soundfile/test_io.py @@ -0,0 +1,73 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
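`BackendTest` above downloads one mono and one multi-channel WAV into the working directory during `setUp`, so concrete backend tests only subclass it and read `self.files`. A hypothetical minimal subclass (the assertion is illustrative; the real checks live in `test_io.py`, which follows):

```python
import unittest

import paddleaudio
from ..base import BackendTest  # package-relative import, as in test_io.py


class TestLoadSmoke(BackendTest):
    def test_mono_is_1d(self):
        # self.files[0] is the mono fixture fetched in setUp().
        data, sr = paddleaudio.load(self.files[0])
        self.assertEqual(len(data.shape), 1)


if __name__ == '__main__':
    unittest.main()
```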
+import filecmp +import os +import unittest + +import numpy as np +import soundfile as sf + +import paddleaudio +from ..base import BackendTest + + +class TestIO(BackendTest): + def test_load_mono_channel(self): + sf_data, sf_sr = sf.read(self.files[0]) + pa_data, pa_sr = paddleaudio.load( + self.files[0], normal=False, dtype='float64') + + self.assertEqual(sf_data.dtype, pa_data.dtype) + self.assertEqual(sf_sr, pa_sr) + np.testing.assert_array_almost_equal(sf_data, pa_data) + + def test_load_multi_channels(self): + sf_data, sf_sr = sf.read(self.files[1]) + sf_data = sf_data.T # Channel dim first + pa_data, pa_sr = paddleaudio.load( + self.files[1], mono=False, normal=False, dtype='float64') + + self.assertEqual(sf_data.dtype, pa_data.dtype) + self.assertEqual(sf_sr, pa_sr) + np.testing.assert_array_almost_equal(sf_data, pa_data) + + def test_save_mono_channel(self): + waveform, sr = np.random.randint( + low=-32768, high=32768, size=(48000), dtype=np.int16), 16000 + sf_tmp_file = 'sf_tmp.wav' + pa_tmp_file = 'pa_tmp.wav' + + sf.write(sf_tmp_file, waveform, sr) + paddleaudio.save(waveform, sr, pa_tmp_file) + + self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) + for file in [sf_tmp_file, pa_tmp_file]: + os.remove(file) + + def test_save_multi_channels(self): + waveform, sr = np.random.randint( + low=-32768, high=32768, size=(2, 48000), dtype=np.int16), 16000 + sf_tmp_file = 'sf_tmp.wav' + pa_tmp_file = 'pa_tmp.wav' + + sf.write(sf_tmp_file, waveform.T, sr) + paddleaudio.save(waveform.T, sr, pa_tmp_file) + + self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file)) + for file in [sf_tmp_file, pa_tmp_file]: + os.remove(file) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/benchmark/README.md b/paddleaudio/tests/benchmark/README.md new file mode 100644 index 00000000000..b9034100d4b --- /dev/null +++ b/paddleaudio/tests/benchmark/README.md @@ -0,0 +1,39 @@ +# 1. Prepare +First, install `pytest-benchmark` via pip. +```sh +pip install pytest-benchmark +``` + +# 2. Run +Run the specific script for profiling. +```sh +pytest melspectrogram.py +``` + +Result: +```sh +========================================================================== test session starts ========================================================================== +platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0 +benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) +rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio +plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0 +collected 4 items + +melspectrogram.py .... 
[100%] + + +-------------------------------------------------------------------------------------------------- benchmark: 4 tests ------------------------------------------------------------------------------------------------- +Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +test_melspect_gpu_torchaudio 202.0765 (1.0) 360.6230 (1.0) 218.1168 (1.0) 16.3022 (1.0) 214.2871 (1.0) 21.8451 (1.0) 40;3 4,584.7001 (1.0) 286 1 +test_melspect_gpu 657.8509 (3.26) 908.0470 (2.52) 724.2545 (3.32) 106.5771 (6.54) 669.9096 (3.13) 113.4719 (5.19) 1;0 1,380.7300 (0.30) 5 1 +test_melspect_cpu_torchaudio 1,247.6053 (6.17) 2,892.5799 (8.02) 1,443.2853 (6.62) 345.3732 (21.19) 1,262.7263 (5.89) 221.6385 (10.15) 56;53 692.8637 (0.15) 399 1 +test_melspect_cpu 20,326.2549 (100.59) 20,607.8682 (57.15) 20,473.4125 (93.86) 63.8654 (3.92) 20,467.0429 (95.51) 68.4294 (3.13) 8;1 48.8438 (0.01) 29 1 +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + +Legend: + Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile. + OPS: Operations Per Second, computed as 1 / Mean +========================================================================== 4 passed in 21.12s =========================================================================== + +``` diff --git a/paddleaudio/tests/benchmark/log_melspectrogram.py b/paddleaudio/tests/benchmark/log_melspectrogram.py new file mode 100644 index 00000000000..5230acd424e --- /dev/null +++ b/paddleaudio/tests/benchmark/log_melspectrogram.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
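Each test in the benchmark scripts takes pytest-benchmark's `benchmark` fixture: it times the wrapped callable over many rounds (producing the table above) and hands back the callable's return value, which is then compared against librosa. The pattern in isolation, with a stand-in workload (a self-contained sketch; run with `pytest`):

```python
import numpy as np


def fake_feature():
    # Stand-in for the melspectrogram()/log_melspectrogram() helpers below.
    return np.fft.rfft(np.random.randn(48000))


def test_fake_feature(benchmark):
    # benchmark() times fake_feature over many rounds and returns its result,
    # so an ordinary assertion can follow the timing.
    result = benchmark(fake_feature)
    assert result.shape == (24001,)  # rfft of 48000 samples -> 24001 bins
```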
+import os +import urllib.request + +import librosa +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +if not os.path.isfile(os.path.basename(wav_url)): + urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) + +waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) +waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) + +# Feature conf +mel_conf = { + 'sr': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, +} + +mel_conf_torchaudio = { + 'sample_rate': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, + 'norm': 'slaney', + 'mel_scale': 'slaney', +} + + +def enable_cpu_device(): + paddle.set_device('cpu') + + +def enable_gpu_device(): + paddle.set_device('gpu') + + +log_mel_extractor = paddleaudio.features.LogMelSpectrogram( + **mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype) + + +def log_melspectrogram(): + return log_mel_extractor(waveform_tensor).squeeze(0) + + +def test_log_melspect_cpu(benchmark): + enable_cpu_device() + feature_paddleaudio = benchmark(log_melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_log_melspect_gpu(benchmark): + enable_gpu_device() + feature_paddleaudio = benchmark(log_melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=2) + + +mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram( + **mel_conf_torchaudio, f_min=0.0) +amplitude_to_DB = torchaudio.transforms.AmplitudeToDB('power', top_db=80.0) + + +def melspectrogram_torchaudio(): + return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0) + + +def log_melspectrogram_torchaudio(): + mel_specgram = mel_extractor_torchaudio(waveform_tensor_torch) + return amplitude_to_DB(mel_specgram).squeeze(0) + + +def test_log_melspect_cpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB + + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu') + waveform_tensor_torch = waveform_tensor_torch.to('cpu') + amplitude_to_DB = amplitude_to_DB.to('cpu') + + feature_paddleaudio = benchmark(log_melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_log_melspect_gpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB + + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda') + waveform_tensor_torch = waveform_tensor_torch.to('cuda') + amplitude_to_DB = amplitude_to_DB.to('cuda') + + feature_torchaudio = benchmark(log_melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0) + np.testing.assert_array_almost_equal( + feature_librosa, feature_torchaudio.cpu(), decimal=2) diff --git a/paddleaudio/tests/benchmark/melspectrogram.py 
b/paddleaudio/tests/benchmark/melspectrogram.py new file mode 100644 index 00000000000..e0b79b45a71 --- /dev/null +++ b/paddleaudio/tests/benchmark/melspectrogram.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import urllib.request + +import librosa +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +if not os.path.isfile(os.path.basename(wav_url)): + urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) + +waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) +waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) + +# Feature conf +mel_conf = { + 'sr': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, +} + +mel_conf_torchaudio = { + 'sample_rate': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, + 'norm': 'slaney', + 'mel_scale': 'slaney', +} + + +def enable_cpu_device(): + paddle.set_device('cpu') + + +def enable_gpu_device(): + paddle.set_device('gpu') + + +mel_extractor = paddleaudio.features.MelSpectrogram( + **mel_conf, f_min=0.0, dtype=waveform_tensor.dtype) + + +def melspectrogram(): + return mel_extractor(waveform_tensor).squeeze(0) + + +def test_melspect_cpu(benchmark): + enable_cpu_device() + feature_paddleaudio = benchmark(melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_melspect_gpu(benchmark): + enable_gpu_device() + feature_paddleaudio = benchmark(melspectrogram) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram( + **mel_conf_torchaudio, f_min=0.0) + + +def melspectrogram_torchaudio(): + return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0) + + +def test_melspect_cpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu') + waveform_tensor_torch = waveform_tensor_torch.to('cpu') + feature_paddleaudio = benchmark(melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_melspect_gpu_torchaudio(benchmark): + global waveform_tensor_torch, mel_extractor_torchaudio + mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda') + waveform_tensor_torch = waveform_tensor_torch.to('cuda') + feature_torchaudio = benchmark(melspectrogram_torchaudio) + feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_torchaudio.cpu(), 
decimal=3) diff --git a/paddleaudio/tests/benchmark/mfcc.py b/paddleaudio/tests/benchmark/mfcc.py new file mode 100644 index 00000000000..2572ff33dd1 --- /dev/null +++ b/paddleaudio/tests/benchmark/mfcc.py @@ -0,0 +1,122 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import urllib.request + +import librosa +import numpy as np +import paddle +import torch +import torchaudio + +import paddleaudio + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' +if not os.path.isfile(os.path.basename(wav_url)): + urllib.request.urlretrieve(wav_url, os.path.basename(wav_url)) + +waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url))) +waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0) +waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0) + +# Feature conf +mel_conf = { + 'sr': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, +} +mfcc_conf = { + 'n_mfcc': 20, + 'top_db': 80.0, +} +mfcc_conf.update(mel_conf) + +mel_conf_torchaudio = { + 'sample_rate': sr, + 'n_fft': 512, + 'hop_length': 128, + 'n_mels': 40, + 'norm': 'slaney', + 'mel_scale': 'slaney', +} +mfcc_conf_torchaudio = { + 'sample_rate': sr, + 'n_mfcc': 20, +} + + +def enable_cpu_device(): + paddle.set_device('cpu') + + +def enable_gpu_device(): + paddle.set_device('gpu') + + +mfcc_extractor = paddleaudio.features.MFCC( + **mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype) + + +def mfcc(): + return mfcc_extractor(waveform_tensor).squeeze(0) + + +def test_mfcc_cpu(benchmark): + enable_cpu_device() + feature_paddleaudio = benchmark(mfcc) + feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_mfcc_gpu(benchmark): + enable_gpu_device() + feature_paddleaudio = benchmark(mfcc) + feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +del mel_conf_torchaudio['sample_rate'] +mfcc_extractor_torchaudio = torchaudio.transforms.MFCC( + **mfcc_conf_torchaudio, melkwargs=mel_conf_torchaudio) + + +def mfcc_torchaudio(): + return mfcc_extractor_torchaudio(waveform_tensor_torch).squeeze(0) + + +def test_mfcc_cpu_torchaudio(benchmark): + global waveform_tensor_torch, mfcc_extractor_torchaudio + + mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu') + waveform_tensor_torch = waveform_tensor_torch.to('cpu') + + feature_paddleaudio = benchmark(mfcc_torchaudio) + feature_librosa = librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddleaudio, decimal=3) + + +def test_mfcc_gpu_torchaudio(benchmark): + global waveform_tensor_torch, mfcc_extractor_torchaudio + + mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cuda') + waveform_tensor_torch = waveform_tensor_torch.to('cuda') + + feature_torchaudio = benchmark(mfcc_torchaudio) + feature_librosa = 
librosa.feature.mfcc(waveform, **mel_conf) + np.testing.assert_array_almost_equal( + feature_librosa, feature_torchaudio.cpu(), decimal=3) diff --git a/paddleaudio/tests/features/__init__.py b/paddleaudio/tests/features/__init__.py new file mode 100644 index 00000000000..97043fd7ba6 --- /dev/null +++ b/paddleaudio/tests/features/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddleaudio/tests/features/base.py b/paddleaudio/tests/features/base.py new file mode 100644 index 00000000000..725e1e2e70b --- /dev/null +++ b/paddleaudio/tests/features/base.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest +import urllib.request + +import numpy as np +import paddle + +from paddleaudio import load + +wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav' + + +class FeatTest(unittest.TestCase): + def setUp(self): + self.initParmas() + self.initWavInput() + self.setUpDevice() + + def setUpDevice(self, device='cpu'): + paddle.set_device(device) + + def initWavInput(self, url=wav_url): + if not os.path.isfile(os.path.basename(url)): + urllib.request.urlretrieve(url, os.path.basename(url)) + self.waveform, self.sr = load(os.path.abspath(os.path.basename(url))) + self.waveform = self.waveform.astype( + np.float32 + ) # paddlespeech.s2t.transform.spectrogram only supports float32 + dim = len(self.waveform.shape) + + assert dim in [1, 2] + if dim == 1: + self.waveform = np.expand_dims(self.waveform, 0) + + def initParmas(self): + raise NotImplementedError diff --git a/paddleaudio/tests/features/test_istft.py b/paddleaudio/tests/features/test_istft.py new file mode 100644 index 00000000000..23371200b62 --- /dev/null +++ b/paddleaudio/tests/features/test_istft.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
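`FeatTest` above normalizes every input to a float32 `(C, T)` array; the tests that follow then compare paddlespeech's `Stft`/`IStft` transforms against `paddle.signal`. The round trip they rely on, shown standalone (a sketch; a hop of 128 against an n_fft of 512 gives 75% overlap, so the periodic Hann window satisfies COLA and the inverse is exact up to the tests' `decimal=5` tolerance):

```python
import numpy as np
import paddle

from paddleaudio.functional.window import get_window

x = paddle.to_tensor(np.random.randn(1, 16000).astype('float64'))
window = get_window('hann', 512, dtype=x.dtype)

spec = paddle.signal.stft(x, n_fft=512, hop_length=128, window=window)
recon = paddle.signal.istft(spec, n_fft=512, hop_length=128, window=window)

np.testing.assert_array_almost_equal(x.numpy(), recon.numpy(), decimal=5)
```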
+import unittest
+
+import numpy as np
+import paddle
+
+from .base import FeatTest
+from paddleaudio.functional.window import get_window
+from paddlespeech.s2t.transform.spectrogram import IStft
+from paddlespeech.s2t.transform.spectrogram import Stft
+
+
+class TestIstft(FeatTest):
+    def initParmas(self):
+        self.n_fft = 512
+        self.hop_length = 128
+        self.window_str = 'hann'
+
+    def test_istft(self):
+        ps_stft = Stft(self.n_fft, self.hop_length)
+        ps_res = ps_stft(
+            self.waveform.T).squeeze(1).T  # (n_fft//2 + 1, n_frames)
+        x = paddle.to_tensor(ps_res)
+
+        ps_istft = IStft(self.hop_length)
+        ps_res = ps_istft(ps_res.T)
+
+        window = get_window(
+            self.window_str, self.n_fft, dtype=self.waveform.dtype)
+        pd_res = paddle.signal.istft(
+            x, self.n_fft, self.hop_length, window=window)
+
+        np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paddleaudio/tests/features/test_kaldi.py b/paddleaudio/tests/features/test_kaldi.py
new file mode 100644
index 00000000000..6e826aaa75b
--- /dev/null
+++ b/paddleaudio/tests/features/test_kaldi.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+import torch
+import torchaudio
+
+import paddleaudio
+from .base import FeatTest
+
+
+class TestKaldi(FeatTest):
+    def initParmas(self):
+        self.window_size = 1024
+        self.dtype = 'float32'
+
+    def test_window(self):
+        t_hann_window = torch.hann_window(
+            self.window_size, periodic=False, dtype=eval(f'torch.{self.dtype}'))
+        t_hamm_window = torch.hamming_window(
+            self.window_size,
+            periodic=False,
+            alpha=0.54,
+            beta=0.46,
+            dtype=eval(f'torch.{self.dtype}'))
+        t_povey_window = torch.hann_window(
+            self.window_size, periodic=False,
+            dtype=eval(f'torch.{self.dtype}')).pow(0.85)
+
+        p_hann_window = paddleaudio.functional.window.get_window(
+            'hann',
+            self.window_size,
+            fftbins=False,
+            dtype=eval(f'paddle.{self.dtype}'))
+        p_hamm_window = paddleaudio.functional.window.get_window(
+            'hamming',
+            self.window_size,
+            fftbins=False,
+            dtype=eval(f'paddle.{self.dtype}'))
+        p_povey_window = paddleaudio.functional.window.get_window(
+            'hann',
+            self.window_size,
+            fftbins=False,
+            dtype=eval(f'paddle.{self.dtype}')).pow(0.85)
+
+        np.testing.assert_array_almost_equal(t_hann_window, p_hann_window)
+        np.testing.assert_array_almost_equal(t_hamm_window, p_hamm_window)
+        np.testing.assert_array_almost_equal(t_povey_window, p_povey_window)
+
+    def test_fbank(self):
+        ta_features = torchaudio.compliance.kaldi.fbank(
+            torch.from_numpy(self.waveform.astype(self.dtype)))
+        pa_features = paddleaudio.compliance.kaldi.fbank(
+            paddle.to_tensor(self.waveform.astype(self.dtype)))
+        np.testing.assert_array_almost_equal(
+            ta_features, pa_features, decimal=4)
+
+    def test_mfcc(self):
+        ta_features = torchaudio.compliance.kaldi.mfcc(
+            torch.from_numpy(self.waveform.astype(self.dtype)))
+        pa_features =
paddleaudio.compliance.kaldi.mfcc( + paddle.to_tensor(self.waveform.astype(self.dtype))) + np.testing.assert_array_almost_equal( + ta_features, pa_features, decimal=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_librosa.py b/paddleaudio/tests/features/test_librosa.py new file mode 100644 index 00000000000..cf0c98c7295 --- /dev/null +++ b/paddleaudio/tests/features/test_librosa.py @@ -0,0 +1,281 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import librosa +import numpy as np +import paddle + +import paddleaudio +from .base import FeatTest +from paddleaudio.functional.window import get_window + + +class TestLibrosa(FeatTest): + def initParmas(self): + self.n_fft = 512 + self.hop_length = 128 + self.n_mels = 40 + self.n_mfcc = 20 + self.fmin = 0.0 + self.window_str = 'hann' + self.pad_mode = 'reflect' + self.top_db = 80.0 + + def test_stft(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + feature_librosa = librosa.core.stft( + y=self.waveform, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=self.window_str, + center=True, + dtype=None, + pad_mode=self.pad_mode, ) + x = paddle.to_tensor(self.waveform).unsqueeze(0) + window = get_window(self.window_str, self.n_fft, dtype=x.dtype) + feature_paddle = paddle.signal.stft( + x=x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=window, + center=True, + pad_mode=self.pad_mode, + normalized=False, + onesided=True, ).squeeze(0) + + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddle, decimal=5) + + def test_istft(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # Get stft result from librosa. 
+ stft_matrix = librosa.core.stft( + y=self.waveform, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=self.window_str, + center=True, + pad_mode=self.pad_mode, ) + + feature_librosa = librosa.core.istft( + stft_matrix=stft_matrix, + hop_length=self.hop_length, + win_length=None, + window=self.window_str, + center=True, + dtype=None, + length=None, ) + + x = paddle.to_tensor(stft_matrix).unsqueeze(0) + window = get_window( + self.window_str, + self.n_fft, + dtype=paddle.to_tensor(self.waveform).dtype) + feature_paddle = paddle.signal.istft( + x=x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=None, + window=window, + center=True, + normalized=False, + onesided=True, + length=None, + return_complex=False, ).squeeze(0) + + np.testing.assert_array_almost_equal( + feature_librosa, feature_paddle, decimal=5) + + def test_mel(self): + feature_librosa = librosa.filters.mel( + sr=self.sr, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.fmin, + fmax=None, + htk=False, + norm='slaney', + dtype=self.waveform.dtype, ) + feature_compliance = paddleaudio.compliance.librosa.compute_fbank_matrix( + sr=self.sr, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.fmin, + fmax=None, + htk=False, + norm='slaney', + dtype=self.waveform.dtype, ) + x = paddle.to_tensor(self.waveform) + feature_functional = paddleaudio.functional.compute_fbank_matrix( + sr=self.sr, + n_fft=self.n_fft, + n_mels=self.n_mels, + f_min=self.fmin, + f_max=None, + htk=False, + norm='slaney', + dtype=x.dtype, ) + + np.testing.assert_array_almost_equal(feature_librosa, + feature_compliance) + np.testing.assert_array_almost_equal(feature_librosa, + feature_functional) + + def test_melspect(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.melspectrogram( + y=self.waveform, + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + + # paddleaudio.compliance.librosa: + feature_compliance = paddleaudio.compliance.librosa.melspectrogram( + x=self.waveform, + sr=self.sr, + window_size=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin, + to_db=False) + + # paddleaudio.features.layer + x = paddle.to_tensor( + self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. 
+ feature_extractor = paddleaudio.features.MelSpectrogram( + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=self.fmin, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal( + feature_librosa, feature_compliance, decimal=5) + np.testing.assert_array_almost_equal( + feature_librosa, feature_layer, decimal=5) + + def test_log_melspect(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.melspectrogram( + y=self.waveform, + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + feature_librosa = librosa.power_to_db(feature_librosa, top_db=None) + + # paddleaudio.compliance.librosa: + feature_compliance = paddleaudio.compliance.librosa.melspectrogram( + x=self.waveform, + sr=self.sr, + window_size=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + + # paddleaudio.features.layer + x = paddle.to_tensor( + self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. + feature_extractor = paddleaudio.features.LogMelSpectrogram( + sr=self.sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=self.fmin, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal( + feature_librosa, feature_compliance, decimal=5) + np.testing.assert_array_almost_equal( + feature_librosa, feature_layer, decimal=4) + + def test_mfcc(self): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.mfcc( + y=self.waveform, + sr=self.sr, + S=None, + n_mfcc=self.n_mfcc, + dct_type=2, + norm='ortho', + lifter=0, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin) + + # paddleaudio.compliance.librosa: + feature_compliance = paddleaudio.compliance.librosa.mfcc( + x=self.waveform, + sr=self.sr, + n_mfcc=self.n_mfcc, + dct_type=2, + norm='ortho', + lifter=0, + window_size=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + fmin=self.fmin, + top_db=self.top_db) + + # paddleaudio.features.layer + x = paddle.to_tensor( + self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim. + feature_extractor = paddleaudio.features.MFCC( + sr=self.sr, + n_mfcc=self.n_mfcc, + n_fft=self.n_fft, + hop_length=self.hop_length, + n_mels=self.n_mels, + f_min=self.fmin, + top_db=self.top_db, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal( + feature_librosa, feature_compliance, decimal=4) + np.testing.assert_array_almost_equal( + feature_librosa, feature_layer, decimal=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/paddleaudio/tests/features/test_log_melspectrogram.py b/paddleaudio/tests/features/test_log_melspectrogram.py new file mode 100644 index 00000000000..6bae2df3f56 --- /dev/null +++ b/paddleaudio/tests/features/test_log_melspectrogram.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+
+import paddleaudio
+from .base import FeatTest
+from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogram
+
+
+class TestLogMelSpectrogram(FeatTest):
+    def initParmas(self):
+        self.n_fft = 512
+        self.hop_length = 128
+        self.n_mels = 40
+
+    def test_log_melspect(self):
+        ps_melspect = LogMelSpectrogram(self.sr, self.n_mels, self.n_fft,
+                                        self.hop_length)
+        ps_res = ps_melspect(self.waveform.T).squeeze(1).T
+
+        x = paddle.to_tensor(self.waveform)
+        # paddlespeech.s2t's features conflate magnitude and power spectra, hence power=1.0 and the /10.0 below.
+        ps_melspect = paddleaudio.features.LogMelSpectrogram(
+            self.sr,
+            self.n_fft,
+            self.hop_length,
+            power=1.0,
+            n_mels=self.n_mels,
+            f_min=0.0)
+        pa_res = (ps_melspect(x) / 10.0).squeeze(0).numpy()
+
+        np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paddleaudio/tests/features/test_spectrogram.py b/paddleaudio/tests/features/test_spectrogram.py
new file mode 100644
index 00000000000..50b21403b4f
--- /dev/null
+++ b/paddleaudio/tests/features/test_spectrogram.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+
+import paddleaudio
+from .base import FeatTest
+from paddlespeech.s2t.transform.spectrogram import Spectrogram
+
+
+class TestSpectrogram(FeatTest):
+    def initParmas(self):
+        self.n_fft = 512
+        self.hop_length = 128
+
+    def test_spectrogram(self):
+        ps_spect = Spectrogram(self.n_fft, self.hop_length)
+        ps_res = ps_spect(self.waveform.T).squeeze(1).T  # Magnitude
+
+        x = paddle.to_tensor(self.waveform)
+        pa_spect = paddleaudio.features.Spectrogram(
+            self.n_fft, self.hop_length, power=1.0)
+        pa_res = pa_spect(x).squeeze(0).numpy()
+
+        np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paddleaudio/tests/features/test_stft.py b/paddleaudio/tests/features/test_stft.py
new file mode 100644
index 00000000000..c64b5ebe6b4
--- /dev/null
+++ b/paddleaudio/tests/features/test_stft.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+
+from .base import FeatTest
+from paddleaudio.functional.window import get_window
+from paddlespeech.s2t.transform.spectrogram import Stft
+
+
+class TestStft(FeatTest):
+    def initParmas(self):
+        self.n_fft = 512
+        self.hop_length = 128
+        self.window_str = 'hann'
+
+    def test_stft(self):
+        ps_stft = Stft(self.n_fft, self.hop_length)
+        ps_res = ps_stft(
+            self.waveform.T).squeeze(1).T  # (n_fft//2 + 1, n_frames)
+
+        x = paddle.to_tensor(self.waveform)
+        window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
+        pd_res = paddle.signal.stft(
+            x, self.n_fft, self.hop_length, window=window).squeeze(0).numpy()
+
+        np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py
index ab5eee6e288..f56d8a579c5 100644
--- a/paddlespeech/cli/cls/infer.py
+++ b/paddlespeech/cli/cls/infer.py
@@ -193,7 +193,8 @@ def preprocess(self, audio_file: Union[str, os.PathLike]):
             sr=feat_conf['sample_rate'],
             mono=True,
             dtype='float32')
-        logger.info("Preprocessing audio_file:" + audio_file)
+        if isinstance(audio_file, (str, os.PathLike)):
+            logger.info("Preprocessing audio_file:" + audio_file)
 
         # Feature extraction
         feature_extractor = LogMelSpectrogram(
diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py
index d77d27b03c1..064939a85da 100644
--- a/paddlespeech/cli/executor.py
+++ b/paddlespeech/cli/executor.py
@@ -178,7 +178,8 @@ def _is_job_input(self, input_: Union[str, os.PathLike]) -> bool:
 
         Returns:
             bool: return `True` for job input, `False` otherwise.
""" - return input_ and os.path.isfile(input_) and input_.endswith('.job') + return input_ and os.path.isfile(input_) and (input_.endswith('.job') or + input_.endswith('.txt')) def _get_job_contents( self, job_input: os.PathLike) -> Dict[str, Union[str, os.PathLike]]: diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 8423dfa8d1c..78eae769bee 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -237,6 +237,30 @@ 'speech_stats': 'feats_stats.npy', }, + "hifigan_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', + 'md5': + '3bb49bc75032ed12f79c00c8cc79a09a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "hifigan_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', + 'md5': + '7da8f88359bca2457e705d924cf27bd4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, # wavernn "wavernn_csmsc-zh": { @@ -365,6 +389,8 @@ def __init__(self): 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc', + 'hifigan_aishell3', + 'hifigan_vctk', 'wavernn_csmsc', ], help='Choose vocoder type of tts task.') diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index d7dcc90c7ab..f7d64b9a95e 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -192,7 +192,7 @@ def __init__(self): try: cfg = yaml.load(file, Loader=yaml.FullLoader) self._data.update(cfg) - except: + except Exception as e: self.flush() @property diff --git a/paddlespeech/server/__init__.py b/paddlespeech/server/__init__.py index 384061ddae2..97722c0a0cb 100644 --- a/paddlespeech/server/__init__.py +++ b/paddlespeech/server/__init__.py @@ -18,6 +18,7 @@ from .base_commands import ServerBaseCommand from .base_commands import ServerHelpCommand from .bin.paddlespeech_client import ASRClientExecutor +from .bin.paddlespeech_client import CLSClientExecutor from .bin.paddlespeech_client import TTSClientExecutor from .bin.paddlespeech_server import ServerExecutor diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index ee6ab7ad764..40f17c63c8e 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -31,7 +31,7 @@ from paddlespeech.server.utils.audio_process import wav2pcm from paddlespeech.server.utils.util import wav2base64 -__all__ = ['TTSClientExecutor', 'ASRClientExecutor'] +__all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor'] @cli_client_register( @@ -70,13 +70,9 @@ def __init__(self): choices=[0, 8000, 16000], help='Sampling rate, the default is the same as the model') self.parser.add_argument( - '--output', - type=str, - default="./output.wav", - help='Synthesized audio file') + '--output', type=str, default=None, help='Synthesized audio file') - def postprocess(self, response_dict: dict, outfile: str) -> float: - wav_base64 = response_dict["result"]["audio"] + def postprocess(self, wav_base64: str, outfile: str) -> float: audio_data_byte = base64.b64decode(wav_base64) # from byte samples, sample_rate = soundfile.read( @@ -93,37 +89,38 @@ def postprocess(self, response_dict: dict, outfile: str) -> float: else: logger.error("The format for saving audio only supports wav or pcm") - duration = len(samples) / sample_rate - return duration - def 
execute(self, argv: List[str]) -> bool:
         args = self.parser.parse_args(argv)
-        try:
-            url = 'http://' + args.server_ip + ":" + str(
-                args.port) + '/paddlespeech/tts'
-            request = {
-                "text": args.input,
-                "spk_id": args.spk_id,
-                "speed": args.speed,
-                "volume": args.volume,
-                "sample_rate": args.sample_rate,
-                "save_path": args.output
-            }
-            st = time.time()
-            response = requests.post(url, json.dumps(request))
-            time_consume = time.time() - st
-
-            response_dict = response.json()
-            duration = self.postprocess(response_dict, args.output)
+        input_ = args.input
+        server_ip = args.server_ip
+        port = args.port
+        spk_id = args.spk_id
+        speed = args.speed
+        volume = args.volume
+        sample_rate = args.sample_rate
+        output = args.output
+        try:
+            time_start = time.time()
+            res = self(
+                input=input_,
+                server_ip=server_ip,
+                port=port,
+                spk_id=spk_id,
+                speed=speed,
+                volume=volume,
+                sample_rate=sample_rate,
+                output=output)
+            time_end = time.time()
+            time_consume = time_end - time_start
+            response_dict = res.json()
             logger.info(response_dict["message"])
-            logger.info("Save synthesized audio successfully on %s." %
-                        (args.output))
-            logger.info("Audio duration: %f s." % (duration))
+            logger.info("Save synthesized audio successfully on %s." % (output))
+            logger.info("Audio duration: %f s." %
+                        (response_dict['result']['duration']))
             logger.info("Response time: %f s." % (time_consume))
-
             return True
-        except BaseException:
+        except Exception as e:
             logger.error("Failed to synthesized audio.")
             return False
 
@@ -136,7 +133,7 @@ def __call__(self,
                  speed: float=1.0,
                  volume: float=1.0,
                  sample_rate: int=0,
-                 output: str="./output.wav"):
+                 output: str=None):
         """
             Python API to call an executor.
         """
@@ -151,20 +148,11 @@ def __call__(self,
             "save_path": output
         }
 
-        try:
-            st = time.time()
-            response = requests.post(url, json.dumps(request))
-            time_consume = time.time() - st
-            response_dict = response.json()
-            duration = self.postprocess(response_dict, output)
-
-            print(response_dict["message"])
-            print("Save synthesized audio successfully on %s." % (output))
-            print("Audio duration: %f s." % (duration))
-            print("Response time: %f s." % (time_consume))
-            print("RTF: %f " % (time_consume / duration))
-        except BaseException:
-            print("Failed to synthesized audio.")
+        res = requests.post(url, json.dumps(request))
+        response_dict = res.json()
+        if output is not None:
+            self.postprocess(response_dict["result"]["audio"], output)
+        return res
 
 
 @cli_client_register(
@@ -193,24 +181,27 @@ def __init__(self):
 
     def execute(self, argv: List[str]) -> bool:
         args = self.parser.parse_args(argv)
-        url = 'http://' + args.server_ip + ":" + str(
-            args.port) + '/paddlespeech/asr'
-        audio = wav2base64(args.input)
-        data = {
-            "audio": audio,
-            "audio_format": args.audio_format,
-            "sample_rate": args.sample_rate,
-            "lang": args.lang,
-        }
-        time_start = time.time()
+        input_ = args.input
+        server_ip = args.server_ip
+        port = args.port
+        sample_rate = args.sample_rate
+        lang = args.lang
+        audio_format = args.audio_format
+
         try:
-            r = requests.post(url=url, data=json.dumps(data))
-            # ending Timestamp
+            time_start = time.time()
+            res = self(
+                input=input_,
+                server_ip=server_ip,
+                port=port,
+                sample_rate=sample_rate,
+                lang=lang,
+                audio_format=audio_format)
             time_end = time.time()
-            logger.info(r.json())
-            logger.info("time cost %f s." % (time_end - time_start))
+            logger.info(res.json())
+            logger.info("Response time %f s." % (time_end - time_start))
             return True
-        except BaseException:
+        except Exception as e:
             logger.error("Failed to speech recognition.")
             return False
 
@@ -234,12 +225,65 @@ def __call__(self,
             "sample_rate": sample_rate,
             "lang": lang,
         }
-        time_start = time.time()
+
+        res = requests.post(url=url, data=json.dumps(data))
+        return res
+
+
+@cli_client_register(
+    name='paddlespeech_client.cls', description='visit cls service')
+class CLSClientExecutor(BaseExecutor):
+    def __init__(self):
+        super(CLSClientExecutor, self).__init__()
+        self.parser = argparse.ArgumentParser(
+            prog='paddlespeech_client.cls', add_help=True)
+        self.parser.add_argument(
+            '--server_ip', type=str, default='127.0.0.1', help='server ip')
+        self.parser.add_argument(
+            '--port', type=int, default=8090, help='server port')
+        self.parser.add_argument(
+            '--input',
+            type=str,
+            default=None,
+            help='Audio file to classify.',
+            required=True)
+        self.parser.add_argument(
+            '--topk',
+            type=int,
+            default=1,
+            help='Return topk scores of classification result.')
+
+    def execute(self, argv: List[str]) -> bool:
+        args = self.parser.parse_args(argv)
+        input_ = args.input
+        server_ip = args.server_ip
+        port = args.port
+        topk = args.topk
+
         try:
-            r = requests.post(url=url, data=json.dumps(data))
-            # ending Timestamp
+            time_start = time.time()
+            res = self(input=input_, server_ip=server_ip, port=port, topk=topk)
             time_end = time.time()
-            print(r.json())
-            print("time cost %f s." % (time_end - time_start))
+            logger.info(res.json())
+            logger.info("Response time %f s." % (time_end - time_start))
+            return True
+        except Exception as e:
+            logger.error("Failed to classify the audio.")
+            return False
+
+    @stats_wrapper
+    def __call__(self,
+                 input: str,
+                 server_ip: str="127.0.0.1",
+                 port: int=8090,
+                 topk: int=1):
+        """
+        Python API to call an executor.
+        """
+
+        url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls'
+        audio = wav2base64(input)
+        data = {"audio": audio, "topk": topk}
+
+        res = requests.post(url=url, data=json.dumps(data))
+        return res
diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index 3d71f091b3d..f6a7f429557 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -103,13 +103,14 @@ def __init__(self):
             '--task',
             type=str,
             default=None,
-            choices=['asr', 'tts'],
+            choices=['asr', 'tts', 'cls'],
             help='Choose speech task.',
             required=True)
-        self.task_choices = ['asr', 'tts']
+        self.task_choices = ['asr', 'tts', 'cls']
         self.model_name_format = {
             'asr': 'Model-Language-Sample Rate',
-            'tts': 'Model-Language'
+            'tts': 'Model-Language',
+            'cls': 'Model-Sample Rate'
         }
 
     def show_support_models(self, pretrained_models: dict):
@@ -174,53 +175,24 @@ def execute(self, argv: List[str]) -> bool:
                 )
                 return False
 
-    @stats_wrapper
-    def __call__(
-            self,
-            task: str=None, ):
-        """
-            Python API to call an executor.
-        """
-        self.task = task
-        if self.task not in self.task_choices:
-            print("Please input correct speech task, choices = ['asr', 'tts']")
-
-        elif self.task == 'asr':
-            try:
-                from paddlespeech.cli.asr.infer import pretrained_models
-                print(
-                    "Here is the table of ASR pretrained models supported in the service."
- ) - self.show_support_models(pretrained_models) - - # show ASR static pretrained model - from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models - print( - "Here is the table of ASR static pretrained models supported in the service." - ) - self.show_support_models(pretrained_models) - - except BaseException: - print( - "Failed to get the table of ASR pretrained models supported in the service." - ) - - elif self.task == 'tts': + elif self.task == 'cls': try: - from paddlespeech.cli.tts.infer import pretrained_models - print( - "Here is the table of TTS pretrained models supported in the service." + from paddlespeech.cli.cls.infer import pretrained_models + logger.info( + "Here is the table of CLS pretrained models supported in the service." ) self.show_support_models(pretrained_models) - # show TTS static pretrained model - from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models - print( - "Here is the table of TTS static pretrained models supported in the service." + # show CLS static pretrained model + from paddlespeech.server.engine.cls.paddleinference.cls_engine import pretrained_models + logger.info( + "Here is the table of CLS static pretrained models supported in the service." ) self.show_support_models(pretrained_models) + return True except BaseException: - print( - "Failed to get the table of TTS pretrained models supported in the service." + logger.error( + "Failed to get the table of CLS pretrained models supported in the service." ) + return False diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml index 6048450b7ba..2b1a0599808 100644 --- a/paddlespeech/server/conf/application.yaml +++ b/paddlespeech/server/conf/application.yaml @@ -9,12 +9,14 @@ port: 8090 # The task format in the engin_list is: _ # task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference'] -engine_list: ['asr_python', 'tts_python'] +engine_list: ['asr_python', 'tts_python', 'cls_python'] ################################################################################# # ENGINE CONFIG # ################################################################################# + +################################### ASR ######################################### ################### speech task: asr; engine_type: python ####################### asr_python: model: 'conformer_wenetspeech' @@ -46,6 +48,7 @@ asr_inference: summary: True # False -> do not show predictor config +################################### TTS ######################################### ################### speech task: tts; engine_type: python ####################### tts_python: # am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc', @@ -105,3 +108,30 @@ tts_inference: # others lang: 'zh' + +################################### CLS ######################################### +################### speech task: cls; engine_type: python ####################### +cls_python: + # model choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6'] + model: 'panns_cnn14' + cfg_path: # [optional] Config of cls task. + ckpt_path: # [optional] Checkpoint file of model. + label_file: # [optional] Label file of cls task. 
+ device: # set 'gpu:id' or 'cpu' + + +################### speech task: cls; engine_type: inference ####################### +cls_inference: + # model_type choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6'] + model_type: 'panns_cnn14' + cfg_path: + model_path: # the pdmodel file of am static model [optional] + params_path: # the pdiparams file of am static model [optional] + label_file: # [optional] Label file of cls task. + + predictor_conf: + device: # set 'gpu:id' or 'cpu' + switch_ir_optim: True + glog_info: False # True -> print glog + summary: True # False -> do not show predictor config + diff --git a/paddlespeech/server/engine/cls/__init__.py b/paddlespeech/server/engine/cls/__init__.py new file mode 100644 index 00000000000..97043fd7ba6 --- /dev/null +++ b/paddlespeech/server/engine/cls/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/cls/paddleinference/__init__.py b/paddlespeech/server/engine/cls/paddleinference/__init__.py new file mode 100644 index 00000000000..97043fd7ba6 --- /dev/null +++ b/paddlespeech/server/engine/cls/paddleinference/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py new file mode 100644 index 00000000000..3982effd902 --- /dev/null +++ b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py @@ -0,0 +1,224 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
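Once a server is started with a `cls_python` or `cls_inference` entry in `engine_list` (config above), the classification service can be exercised through the new client executor; a minimal round trip using the `__call__` signature added earlier in this patch (the IP, port, and file name are the illustrative defaults from the argument parsers):

```python
from paddlespeech.server.bin.paddlespeech_client import CLSClientExecutor

cls_client = CLSClientExecutor()
res = cls_client(input='./zh.wav', server_ip='127.0.0.1', port=8090, topk=3)
print(res.json())  # top-3 labels with probabilities, per postprocess() below
```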
+import io +import os +import time +from typing import Optional + +import numpy as np +import paddle +import yaml + +from paddlespeech.cli.cls.infer import CLSExecutor +from paddlespeech.cli.log import logger +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.server.engine.base_engine import BaseEngine +from paddlespeech.server.utils.paddle_predictor import init_predictor +from paddlespeech.server.utils.paddle_predictor import run_model + +__all__ = ['CLSEngine'] + +pretrained_models = { + "panns_cnn6-32k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz', + 'md5': + 'da087c31046d23281d8ec5188c1967da', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + "panns_cnn10-32k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz', + 'md5': + '5460cc6eafbfaf0f261cc75b90284ae1', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, + "panns_cnn14-32k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz', + 'md5': + 'ccc80b194821274da79466862b2ab00f', + 'cfg_path': + 'panns.yaml', + 'model_path': + 'inference.pdmodel', + 'params_path': + 'inference.pdiparams', + 'label_file': + 'audioset_labels.txt', + }, +} + + +class CLSServerExecutor(CLSExecutor): + def __init__(self): + super().__init__() + pass + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. + """ + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + + def _init_from_path( + self, + model_type: str='panns_cnn14', + cfg_path: Optional[os.PathLike]=None, + model_path: Optional[os.PathLike]=None, + params_path: Optional[os.PathLike]=None, + label_file: Optional[os.PathLike]=None, + predictor_conf: dict=None, ): + """ + Init model and other resources from a specific path. 
+ """ + + if cfg_path is None or model_path is None or params_path is None or label_file is None: + tag = model_type + '-' + '32k' + self.res_path = self._get_pretrained_path(tag) + self.cfg_path = os.path.join(self.res_path, + pretrained_models[tag]['cfg_path']) + self.model_path = os.path.join(self.res_path, + pretrained_models[tag]['model_path']) + self.params_path = os.path.join( + self.res_path, pretrained_models[tag]['params_path']) + self.label_file = os.path.join(self.res_path, + pretrained_models[tag]['label_file']) + else: + self.cfg_path = os.path.abspath(cfg_path) + self.model_path = os.path.abspath(model_path) + self.params_path = os.path.abspath(params_path) + self.label_file = os.path.abspath(label_file) + + logger.info(self.cfg_path) + logger.info(self.model_path) + logger.info(self.params_path) + logger.info(self.label_file) + + # config + with open(self.cfg_path, 'r') as f: + self._conf = yaml.safe_load(f) + logger.info("Read cfg file successfully.") + + # labels + self._label_list = [] + with open(self.label_file, 'r') as f: + for line in f: + self._label_list.append(line.strip()) + logger.info("Read label file successfully.") + + # Create predictor + self.predictor_conf = predictor_conf + self.predictor = init_predictor( + model_file=self.model_path, + params_file=self.params_path, + predictor_conf=self.predictor_conf) + logger.info("Create predictor successfully.") + + @paddle.no_grad() + def infer(self): + """ + Model inference and result stored in self.output. + """ + output = run_model(self.predictor, [self._inputs['feats'].numpy()]) + self._outputs['logits'] = output[0] + + +class CLSEngine(BaseEngine): + """CLS server engine + + Args: + metaclass: Defaults to Singleton. + """ + + def __init__(self): + super(CLSEngine, self).__init__() + + def init(self, config: dict) -> bool: + """init engine resource + + Args: + config_file (str): config file + + Returns: + bool: init failed or success + """ + self.executor = CLSServerExecutor() + self.config = config + self.executor._init_from_path( + self.config.model_type, self.config.cfg_path, + self.config.model_path, self.config.params_path, + self.config.label_file, self.config.predictor_conf) + + logger.info("Initialize CLS server engine successfully.") + return True + + def run(self, audio_data): + """engine run + + Args: + audio_data (bytes): base64.b64decode + """ + + self.executor.preprocess(io.BytesIO(audio_data)) + st = time.time() + self.executor.infer() + infer_time = time.time() - st + + logger.info("inference time: {}".format(infer_time)) + logger.info("cls engine type: inference") + + def postprocess(self, topk: int): + """postprocess + """ + assert topk <= len(self.executor._label_list + ), 'Value of topk is larger than number of labels.' + + result = np.squeeze(self.executor._outputs['logits'], axis=0) + topk_idx = (-result).argsort()[:topk] + topk_results = [] + for idx in topk_idx: + res = {} + label, score = self.executor._label_list[idx], result[idx] + res['class_name'] = label + res['prob'] = score + topk_results.append(res) + + return topk_results diff --git a/paddlespeech/server/engine/cls/python/__init__.py b/paddlespeech/server/engine/cls/python/__init__.py new file mode 100644 index 00000000000..97043fd7ba6 --- /dev/null +++ b/paddlespeech/server/engine/cls/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/cls/python/cls_engine.py b/paddlespeech/server/engine/cls/python/cls_engine.py
new file mode 100644
index 00000000000..1a975b0a05b
--- /dev/null
+++ b/paddlespeech/server/engine/cls/python/cls_engine.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import time
+from typing import List
+
+import paddle
+
+from paddlespeech.cli.cls.infer import CLSExecutor
+from paddlespeech.cli.log import logger
+from paddlespeech.server.engine.base_engine import BaseEngine
+
+__all__ = ['CLSEngine']
+
+
+class CLSServerExecutor(CLSExecutor):
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def get_topk_results(self, topk: int) -> List:
+        assert topk <= len(
+            self._label_list), 'Value of topk is larger than number of labels.'
+
+        result = self._outputs['logits'].squeeze(0).numpy()
+        topk_idx = (-result).argsort()[:topk]
+        topk_results = []
+        for idx in topk_idx:
+            # build a fresh dict per entry; reusing one dict would make every
+            # list element alias the last result
+            res = {}
+            label, score = self._label_list[idx], result[idx]
+            res['class_name'] = label
+            res['prob'] = score
+            topk_results.append(res)
+        return topk_results
+
+
+class CLSEngine(BaseEngine):
+    """CLS server engine
+
+    Args:
+        metaclass: Defaults to Singleton.
+    """
+
+    def __init__(self):
+        super(CLSEngine, self).__init__()
+
+    def init(self, config: dict) -> bool:
+        """init engine resource
+
+        Args:
+            config (dict): engine config from the server yaml
+
+        Returns:
+            bool: True on success, False on failure
+        """
+        self.input = None
+        self.output = None
+        self.executor = CLSServerExecutor()
+        self.config = config
+        try:
+            if self.config.device:
+                self.device = self.config.device
+            else:
+                self.device = paddle.get_device()
+            paddle.set_device(self.device)
+        except BaseException:
+            logger.error(
+                "Failed to set device. Please check whether the device is already in use and whether the 'device' field in the yaml file is valid."
+            )
+
+        try:
+            self.executor._init_from_path(
+                self.config.model, self.config.cfg_path, self.config.ckpt_path,
+                self.config.label_file)
+        except BaseException:
+            logger.error("Failed to initialize CLS server engine.")
+            return False
+
+        logger.info("Initialize CLS server engine successfully on device: %s."
% + (self.device)) + return True + + def run(self, audio_data): + """engine run + + Args: + audio_data (bytes): base64.b64decode + """ + self.executor.preprocess(io.BytesIO(audio_data)) + st = time.time() + self.executor.infer() + infer_time = time.time() - st + + logger.info("inference time: {}".format(infer_time)) + logger.info("cls engine type: python") + + def postprocess(self, topk: int): + """postprocess + """ + assert topk <= len(self.executor._label_list + ), 'Value of topk is larger than number of labels.' + + result = self.executor._outputs['logits'].squeeze(0).numpy() + topk_idx = (-result).argsort()[:topk] + topk_results = [] + for idx in topk_idx: + res = {} + label, score = self.executor._label_list[idx], result[idx] + res['class_name'] = label + res['prob'] = score + topk_results.append(res) + + return topk_results diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py index 546541edfcf..c39c44cae5f 100644 --- a/paddlespeech/server/engine/engine_factory.py +++ b/paddlespeech/server/engine/engine_factory.py @@ -31,5 +31,11 @@ def get_engine(engine_name: Text, engine_type: Text): elif engine_name == 'tts' and engine_type == 'python': from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine return TTSEngine() + elif engine_name == 'cls' and engine_type == 'inference': + from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine + return CLSEngine() + elif engine_name == 'cls' and engine_type == 'python': + from paddlespeech.server.engine.cls.python.cls_engine import CLSEngine + return CLSEngine() else: return None diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py index 1bbbe0ea3e1..db8813ba901 100644 --- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py +++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py @@ -250,27 +250,21 @@ def _init_from_path( self.frontend = English(phone_vocab_path=self.phones_dict) logger.info("frontend done!") - try: - # am predictor - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - logger.info("Create AM predictor successfully.") - except BaseException: - logger.error("Failed to create AM predictor.") - - try: - # voc predictor - self.voc_predictor_conf = voc_predictor_conf - self.voc_predictor = init_predictor( - model_file=self.voc_model, - params_file=self.voc_params, - predictor_conf=self.voc_predictor_conf) - logger.info("Create Vocoder predictor successfully.") - except BaseException: - logger.error("Failed to create Vocoder predictor.") + # Create am predictor + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + logger.info("Create AM predictor successfully.") + + # Create voc predictor + self.voc_predictor_conf = voc_predictor_conf + self.voc_predictor = init_predictor( + model_file=self.voc_model, + params_file=self.voc_params, + predictor_conf=self.voc_predictor_conf) + logger.info("Create Vocoder predictor successfully.") @paddle.no_grad() def infer(self, @@ -359,27 +353,22 @@ def __init__(self): def init(self, config: dict) -> bool: self.executor = TTSServerExecutor() - try: - self.config = config - self.executor._init_from_path( - am=self.config.am, - am_model=self.config.am_model, - 
am_params=self.config.am_params, - am_sample_rate=self.config.am_sample_rate, - phones_dict=self.config.phones_dict, - tones_dict=self.config.tones_dict, - speaker_dict=self.config.speaker_dict, - voc=self.config.voc, - voc_model=self.config.voc_model, - voc_params=self.config.voc_params, - voc_sample_rate=self.config.voc_sample_rate, - lang=self.config.lang, - am_predictor_conf=self.config.am_predictor_conf, - voc_predictor_conf=self.config.voc_predictor_conf, ) - - except BaseException: - logger.error("Initialize TTS server engine Failed.") - return False + self.config = config + self.executor._init_from_path( + am=self.config.am, + am_model=self.config.am_model, + am_params=self.config.am_params, + am_sample_rate=self.config.am_sample_rate, + phones_dict=self.config.phones_dict, + tones_dict=self.config.tones_dict, + speaker_dict=self.config.speaker_dict, + voc=self.config.voc, + voc_model=self.config.voc_model, + voc_params=self.config.voc_params, + voc_sample_rate=self.config.voc_sample_rate, + lang=self.config.lang, + am_predictor_conf=self.config.am_predictor_conf, + voc_predictor_conf=self.config.voc_predictor_conf, ) logger.info("Initialize TTS server engine successfully.") return True @@ -542,4 +531,4 @@ def run(self, postprocess_time)) logger.info("RTF: {}".format(rtf)) - return lang, target_sample_rate, wav_base64 + return lang, target_sample_rate, duration, wav_base64 diff --git a/paddlespeech/server/engine/tts/python/tts_engine.py b/paddlespeech/server/engine/tts/python/tts_engine.py index 8d6c7fd17e5..f153f60b966 100644 --- a/paddlespeech/server/engine/tts/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/python/tts_engine.py @@ -250,4 +250,4 @@ def run(self, logger.info("RTF: {}".format(rtf)) logger.info("device: {}".format(self.device)) - return lang, target_sample_rate, wav_base64 + return lang, target_sample_rate, duration, wav_base64 diff --git a/paddlespeech/server/restful/api.py b/paddlespeech/server/restful/api.py index 2d69dee8739..3f91a03b647 100644 --- a/paddlespeech/server/restful/api.py +++ b/paddlespeech/server/restful/api.py @@ -16,6 +16,7 @@ from fastapi import APIRouter from paddlespeech.server.restful.asr_api import router as asr_router +from paddlespeech.server.restful.cls_api import router as cls_router from paddlespeech.server.restful.tts_api import router as tts_router _router = APIRouter() @@ -25,7 +26,7 @@ def setup_router(api_list: List): """setup router for fastapi Args: - api_list (List): [asr, tts] + api_list (List): [asr, tts, cls] Returns: APIRouter @@ -35,6 +36,8 @@ def setup_router(api_list: List): _router.include_router(asr_router) elif api_name == 'tts': _router.include_router(tts_router) + elif api_name == 'cls': + _router.include_router(cls_router) else: pass diff --git a/paddlespeech/server/restful/cls_api.py b/paddlespeech/server/restful/cls_api.py new file mode 100644 index 00000000000..306d9ca9c11 --- /dev/null +++ b/paddlespeech/server/restful/cls_api.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
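With the change above, both TTS engines now return the synthesized audio's duration as a fourth value, so any direct caller must unpack it. A hedged sketch of the updated call site, mirroring `tts_api.py` below; it assumes the server's engine pool has already been initialized with a `tts` engine, and the literal argument values are for illustration only:

```python
# A minimal sketch; the argument order follows the tts_api.py call site.
from paddlespeech.server.engine.engine_pool import get_engine_pool

tts_engine = get_engine_pool()['tts']
lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
    "你好。",  # text
    0,         # spk_id
    1.0,       # speed
    1.0,       # volume
    0,         # sample_rate (0 is the TTSRequest default)
    None)      # save_path
print(f"{duration}s of audio at {target_sample_rate} Hz")
```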
+# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import traceback +from typing import Union + +from fastapi import APIRouter + +from paddlespeech.server.engine.engine_pool import get_engine_pool +from paddlespeech.server.restful.request import CLSRequest +from paddlespeech.server.restful.response import CLSResponse +from paddlespeech.server.restful.response import ErrorResponse +from paddlespeech.server.utils.errors import ErrorCode +from paddlespeech.server.utils.errors import failed_response +from paddlespeech.server.utils.exception import ServerBaseException + +router = APIRouter() + + +@router.get('/paddlespeech/cls/help') +def help(): + """help + + Returns: + json: [description] + """ + response = { + "success": "True", + "code": 200, + "message": { + "global": "success" + }, + "result": { + "description": "cls server", + "input": "base64 string of wavfile", + "output": "classification result" + } + } + return response + + +@router.post( + "/paddlespeech/cls", response_model=Union[CLSResponse, ErrorResponse]) +def cls(request_body: CLSRequest): + """cls api + + Args: + request_body (CLSRequest): [description] + + Returns: + json: [description] + """ + try: + audio_data = base64.b64decode(request_body.audio) + + # get single engine from engine pool + engine_pool = get_engine_pool() + cls_engine = engine_pool['cls'] + + cls_engine.run(audio_data) + cls_results = cls_engine.postprocess(request_body.topk) + + response = { + "success": True, + "code": 200, + "message": { + "description": "success" + }, + "result": { + "topk": request_body.topk, + "results": cls_results + } + } + + except ServerBaseException as e: + response = failed_response(e.error_code, e.msg) + except BaseException: + response = failed_response(ErrorCode.SERVER_UNKOWN_ERR) + traceback.print_exc() + + return response diff --git a/paddlespeech/server/restful/request.py b/paddlespeech/server/restful/request.py index 28908801977..dbac9dac881 100644 --- a/paddlespeech/server/restful/request.py +++ b/paddlespeech/server/restful/request.py @@ -15,7 +15,7 @@ from pydantic import BaseModel -__all__ = ['ASRRequest', 'TTSRequest'] +__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest'] #****************************************************************************************/ @@ -63,3 +63,18 @@ class TTSRequest(BaseModel): volume: float = 1.0 sample_rate: int = 0 save_path: str = None + + +#****************************************************************************************/ +#************************************ CLS request ***************************************/ +#****************************************************************************************/ +class CLSRequest(BaseModel): + """ + request body example + { + "audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...", + "topk": 1 + } + """ + audio: str + topk: int = 1 diff --git a/paddlespeech/server/restful/response.py b/paddlespeech/server/restful/response.py index 4e18ee0d790..a2a207e4f68 100644 --- a/paddlespeech/server/restful/response.py +++ b/paddlespeech/server/restful/response.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
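With the router registered, the new endpoint can be exercised directly over HTTP. A hedged client sketch: host and port are assumptions, the input wav is hypothetical, and the payload fields come from `CLSRequest`:

```python
# A minimal sketch, assuming a paddlespeech server listening on
# 127.0.0.1:8090 (hypothetical) with 'cls' in its api_list.
import base64

import requests

with open('zh.wav', 'rb') as f:  # hypothetical input wav
    audio = base64.b64encode(f.read()).decode('utf8')

resp = requests.post(
    'http://127.0.0.1:8090/paddlespeech/cls',
    json={'audio': audio, 'topk': 3})
for item in resp.json()['result']['results']:
    print(item['class_name'], item['prob'])
```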
+from typing import List
+
 from pydantic import BaseModel
 
-__all__ = ['ASRResponse', 'TTSResponse']
+__all__ = ['ASRResponse', 'TTSResponse', 'CLSResponse']
 
 
 class Message(BaseModel):
@@ -52,10 +54,11 @@ class ASRResponse(BaseModel):
 #****************************************************************************************/
 class TTSResult(BaseModel):
     lang: str = "zh"
-    sample_rate: int
     spk_id: int = 0
     speed: float = 1.0
     volume: float = 1.0
+    sample_rate: int
+    duration: float
     save_path: str = None
     audio: str
 
@@ -71,9 +74,11 @@ class TTSResponse(BaseModel):
         },
         "result": {
             "lang": "zh",
-            "sample_rate": 24000,
+            "spk_id": 0,
             "speed": 1.0,
             "volume": 1.0,
+            "sample_rate": 24000,
+            "duration": 3.6125,
             "audio": "LTI1OTIuNjI1OTUwMzQsOTk2OS41NDk4...",
             "save_path": "./tts.wav"
         }
@@ -85,6 +90,45 @@ class TTSResponse(BaseModel):
     result: TTSResult
 
 
+#****************************************************************************************/
+#************************************ CLS response **************************************/
+#****************************************************************************************/
+class CLSResults(BaseModel):
+    class_name: str
+    prob: float
+
+
+class CLSResult(BaseModel):
+    topk: int
+    results: List[CLSResults]
+
+
+class CLSResponse(BaseModel):
+    """
+    response example
+    {
+        "success": true,
+        "code": 200,
+        "message": {
+            "description": "success"
+        },
+        "result": {
+            "topk": 1,
+            "results": [
+                {
+                    "class_name": "Speech",
+                    "prob": 0.9027184844017029
+                }
+            ]
+        }
+    }
+    """
+    success: bool
+    code: int
+    message: Message
+    result: CLSResult
+
+
 #****************************************************************************************/
 #********************************** Error response **************************************/
 #****************************************************************************************/
diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py
index 0af0f6d0790..4e9bbe23ed3 100644
--- a/paddlespeech/server/restful/tts_api.py
+++ b/paddlespeech/server/restful/tts_api.py
@@ -98,7 +98,7 @@ def tts(request_body: TTSRequest):
         tts_engine = engine_pool['tts']
         logger.info("Get tts engine successfully.")
 
-        lang, target_sample_rate, wav_base64 = tts_engine.run(
+        lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
             text, spk_id, speed, volume, sample_rate, save_path)
 
         response = {
@@ -113,6 +113,7 @@ def tts(request_body: TTSRequest):
                 "speed": speed,
                 "volume": volume,
                 "sample_rate": target_sample_rate,
+                "duration": duration,
                 "save_path": save_path,
                 "audio": wav_base64
             }
diff --git a/paddlespeech/server/utils/paddle_predictor.py b/paddlespeech/server/utils/paddle_predictor.py
index 4035d48d8c9..16653cf372e 100644
--- a/paddlespeech/server/utils/paddle_predictor.py
+++ b/paddlespeech/server/utils/paddle_predictor.py
@@ -35,10 +35,13 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
     Returns:
         predictor (PaddleInferPredictor): created predictor
     """
     if model_dir is not None:
+        assert os.path.isdir(model_dir), 'Please check model dir.'
         config = Config(args.model_dir)
     else:
+        assert os.path.isfile(model_file) and os.path.isfile(
+            params_file), 'Please check model and parameter files.'
config = Config(model_file, params_file) # set device @@ -66,7 +68,6 @@ def init_predictor(model_dir: Optional[os.PathLike]=None, config.enable_memory_optim() predictor = create_predictor(config) - return predictor @@ -84,10 +85,8 @@ def run_model(predictor, input: List) -> List: for i, name in enumerate(input_names): input_handle = predictor.get_input_handle(name) input_handle.copy_from_cpu(input[i]) - # do the inference predictor.run() - results = [] # get out data from output tensor output_names = predictor.get_output_names() diff --git a/paddlespeech/t2s/exps/csmsc_test.txt b/paddlespeech/t2s/exps/csmsc_test.txt new file mode 100644 index 00000000000..d8cf367cd0c --- /dev/null +++ b/paddlespeech/t2s/exps/csmsc_test.txt @@ -0,0 +1,100 @@ +009901 昨日,这名伤者与医生全部被警方依法刑事拘留。 +009902 钱伟长想到上海来办学校是经过深思熟虑的。 +009903 她见我一进门就骂,吃饭时也骂,骂得我抬不起头。 +009904 李述德在离开之前,只说了一句柱驼杀父亲了。 +009905 这种车票和保险单捆绑出售属于重复性购买。 +009906 戴佩妮的男友西米露接唱情歌,让她非常开心。 +009907 观大势,谋大局,出大策始终是该院的办院方针。 +009908 他们骑着摩托回家,正好为农忙时的父母帮忙。 +009909 但是因为还没到退休年龄,只能掰着指头捱日子。 +009910 这几天雨水不断,人们恨不得待在家里不出门。 +009911 没想到徐赟,张海翔两人就此玩起了人间蒸发。 +009912 藤村此番发言可能是为了凸显野田的领导能力。 +009913 程长庚,生在清王朝嘉庆年间,安徽的潜山小县。 +009914 南海海域综合补给基地码头项目正在论证中。 +009915 也就是说今晚成都市民极有可能再次看到飘雪。 +009916 随着天气转热,各地的游泳场所开始人头攒动。 +009917 更让徐先生纳闷的是,房客的手机也打不通了。 +009918 遇到颠簸时,应听从乘务员的安全指令,回座位坐好。 +009919 他在后面呆惯了,怕自己一插身后的人会不满,不敢排进去。 +009920 傍晚七个小人回来了,白雪公主说,你们就是我命中的七个小矮人吧。 +009921 他本想说,教育局管这个,他们是一路的,这样一管岂不是妓女起嫖客? +009922 一种表示商品所有权的财物证券,也称商品证券,如提货单,交货单。 +009923 会有很丰富的东西留下来,说都说不完。 +009924 这句话像从天而降,吓得四周一片寂静。 +009925 记者所在的是受害人家属所在的右区。 +009926 不管哈大爷去哪,它都一步不离地跟着。 +009927 大家抬头望去,一只老鼠正趴在吊顶上。 +009928 我决定过年就辞职,接手我爸的废品站! +009929 最终,中国男子乒乓球队获得此奖项。 +009930 防汛抗旱两手抓,抗旱相对抓的不够。 +009931 图们江下游地区开发开放的进展如何? +009932 这要求中国必须有一个坚强的政党领导。 +009933 再说,关于利益上的事俺俩都不好开口。 +009934 明代瓦剌,鞑靼入侵明境也是通过此地。 +009935 咪咪舔着孩子,把它身上的毛舔干净。 +009936 是否这次的国标修订被大企业绑架了? +009937 判决后,姚某妻子胡某不服,提起上诉。 +009938 由此可以看出邯钢的经济效益来自何处。 +009939 琳达说,是瑜伽改变了她和马儿的生活。 +009940 楼下的保安告诉记者,这里不租也不卖。 +009941 习近平说,中斯两国人民传统友谊深厚。 +009942 传闻越来越多,后来连老汉儿自己都怕了。 +009943 我怒吼一声冲上去,举起砖头砸了过去。 +009944 我现在还不会,这就回去问问发明我的人。 +009945 显然,洛阳性奴案不具备上述两个前提。 +009946 另外,杰克逊有文唇线,眼线,眉毛的动作。 +009947 昨晚,华西都市报记者电话采访了尹琪。 +009948 涅拉季科未透露这些航空公司的名称。 +009949 从运行轨迹上来说,它也不可能是星星。 +009950 目前看,如果继续加息也存在两难问题。 +009951 曾宝仪在节目录制现场大爆观众糗事。 +009952 但任凭周某怎么叫,男子仍酣睡不醒。 +009953 老大爷说,小子,你挡我财路了,知道不? +009954 没料到,闯下大头佛的阿伟还不知悔改。 +009955 卡扎菲部落式统治已遭遇部落内讧。 +009956 这个孩子的生命一半来源于另一位女士捐赠的冷冻卵子。 +009957 出现这种泥鳅内阁的局面既是野田有意为之,也实属无奈。 +009958 济青高速济南,华山,章丘,邹平,周村,淄博,临淄站。 +009959 赵凌飞的话,反映了沈阳赛区所有奥运志愿者的共同心声。 +009960 因为,我们所发出的力量必会因难度加大而减弱。 +009961 发生事故的楼梯拐角处仍可看到血迹。 +009962 想过进公安,可能身高不够,老汉儿也不让我进去。 +009963 路上关卡很多,为了方便撤离,只好轻装前进。 +009964 原来比尔盖茨就是美国微软公司联合创始人呀。 +009965 之后他们一家三口将与双方父母往峇里岛旅游。 +009966 谢谢总理,也感谢广大网友的参与,我们明年再见。 +009967 事实上是,从来没有一个欺善怕恶的人能作出过稍大一点的成就。 +009968 我会打开邮件,你可以从那里继续。 +009969 美方对近期东海局势表示关切。 +009970 据悉,奥巴马一家人对这座冬季白宫极为满意。 +009971 打扫完你会很有成就感的,试一试,你就信了。 +009972 诺曼站在滑板车上,各就各位,准备出发啦! +009973 塔河的寒夜,气温降到了零下三十多摄氏度。 +009974 其间,连破六点六,六点五,六点四,六点三五等多个重要关口。 +009975 算命其实只是人们的一种自我安慰和自我暗示而已,我们还是要相信科学才好。 +009976 这一切都令人欢欣鼓舞,阿讷西没理由不坚持到最后。 +009977 直至公元前一万一千年,它又再次出现。 +009978 尽量少玩电脑,少看电视,少打游戏。 +009979 从五到七,前后也就是六个月的时间。 +009980 一进咖啡店,他就遇见一张熟悉的脸。 +009981 好在众弟兄看到了把她追了回来。 +009982 有一个人说,哥们儿我们跑过它才能活。 +009983 捅了她以后,模糊记得她没咋动了。 +009984 从小到大,葛启义没有收到过压岁钱。 +009985 舞台下的你会对舞台上的你说什么? +009986 但考生普遍认为,试题的怪多过难。 +009987 我希望每个人都能够尊重我们的隐私。 +009988 漫天的红霞使劲给两人增添气氛。 +009989 晚上加完班开车回家,太累了,迷迷糊糊开着车,走一半的时候,铛一声! +009990 该车将三人撞倒后,在大雾中逃窜。 +009991 这人一哆嗦,方向盘也把不稳了,差点撞上了高速边道护栏。 +009992 那女孩儿委屈的说,我一回头见你已经进去了我不敢进去啊! 
+009993 小明摇摇头说,不是,我只是美女看多了,想换个口味而已。 +009994 接下来,红娘要求记者交费,记者表示不知表姐身份证号码。 +009995 李东蓊表示,自己当时在法庭上发表了一次独特的公诉意见。 +009996 另一男子扑了上来,手里拿着明晃晃的长刀,向他胸口直刺。 +009997 今天,快递员拿着一个快递在办公室喊,秦王是哪个,有他快递? +009998 这场抗议活动究竟是如何发展演变的,又究竟是谁伤害了谁? +009999 因华国锋肖鸡,墓地设计根据其属相设计。 +010000 在狱中,张明宝悔恨交加,写了一份忏悔书。 diff --git a/paddlespeech/t2s/exps/gan_vocoder/synthesize.py b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py index c60b9add2eb..9d9a8c49b68 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/synthesize.py +++ b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py @@ -34,7 +34,7 @@ def main(): "--generator-type", type=str, default="pwgan", - help="type of GANVocoder, should in {pwgan, mb_melgan, style_melgan, } now" + help="type of GANVocoder, should in {pwgan, mb_melgan, style_melgan, hifigan, } now" ) parser.add_argument("--config", type=str, help="GANVocoder config file.") parser.add_argument("--checkpoint", type=str, help="snapshot to load.") diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 26d7e2c089c..1188ddfb132 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -17,13 +17,92 @@ import numpy import soundfile as sf from paddle import inference - -from paddlespeech.t2s.frontend import English -from paddlespeech.t2s.frontend.zh_frontend import Frontend +from timer import timer + +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.utils import str2bool + + +def get_predictor(args, filed='am'): + full_name = '' + if filed == 'am': + full_name = args.am + elif filed == 'voc': + full_name = args.voc + model_name = full_name[:full_name.rindex('_')] + config = inference.Config( + str(Path(args.inference_dir) / (full_name + ".pdmodel")), + str(Path(args.inference_dir) / (full_name + ".pdiparams"))) + if args.device == "gpu": + config.enable_use_gpu(100, 0) + elif args.device == "cpu": + config.disable_gpu() + # This line must be commented for fastspeech2, if not, it will OOM + if model_name != 'fastspeech2': + config.enable_memory_optim() + predictor = inference.create_predictor(config) + return predictor -# only inference for models trained with csmsc now -def main(): +def get_am_output(args, am_predictor, frontend, merge_sentences, input): + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + am_input_names = am_predictor.get_input_names() + get_tone_ids = False + get_spk_id = False + if am_name == 'speedyspeech': + get_tone_ids = True + if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + get_spk_id = True + spk_id = numpy.array([args.spk_id]) + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + tones = tone_ids[0].numpy() + tones_handle = am_predictor.get_input_handle(am_input_names[1]) + tones_handle.reshape(tones.shape) + tones_handle.copy_from_cpu(tones) + if get_spk_id: + spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) + spk_id_handle.reshape(spk_id.shape) + spk_id_handle.copy_from_cpu(spk_id) + phones = phone_ids[0].numpy() + phones_handle = am_predictor.get_input_handle(am_input_names[0]) + phones_handle.reshape(phones.shape) + 
phones_handle.copy_from_cpu(phones) + + am_predictor.run() + am_output_names = am_predictor.get_output_names() + am_output_handle = am_predictor.get_output_handle(am_output_names[0]) + am_output_data = am_output_handle.copy_to_cpu() + return am_output_data + + +def get_voc_output(args, voc_predictor, input): + voc_input_names = voc_predictor.get_input_names() + mel_handle = voc_predictor.get_input_handle(voc_input_names[0]) + mel_handle.reshape(input.shape) + mel_handle.copy_from_cpu(input) + + voc_predictor.run() + voc_output_names = voc_predictor.get_output_names() + voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0]) + wav = voc_output_handle.copy_to_cpu() + return wav + + +def parse_args(): parser = argparse.ArgumentParser( description="Paddle Infernce with speedyspeech & parallel wavegan.") # acoustic model @@ -70,113 +149,97 @@ def main(): parser.add_argument( "--inference_dir", type=str, help="dir to save inference models") parser.add_argument("--output_dir", type=str, help="output dir") + # inference + parser.add_argument( + "--use_trt", + type=str2bool, + default=False, + help="Whether to use inference engin TensorRT.", ) + parser.add_argument( + "--int8", + type=str2bool, + default=False, + help="Whether to use int8 inference.", ) + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to use float16 inference.", ) + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu"], + help="Device selected for inference.", ) args, _ = parser.parse_known_args() + return args + +# only inference for models trained with csmsc now +def main(): + args = parse_args() # frontend - if args.lang == 'zh': - frontend = Frontend( - phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) - elif args.lang == 'en': - frontend = English(phone_vocab_path=args.phones_dict) - print("frontend done!") + frontend = get_frontend(args) + # am_predictor + am_predictor = get_predictor(args, filed='am') # model: {model_name}_{dataset} - am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] - am_config = inference.Config( - str(Path(args.inference_dir) / (args.am + ".pdmodel")), - str(Path(args.inference_dir) / (args.am + ".pdiparams"))) - am_config.enable_use_gpu(100, 0) - # This line must be commented for fastspeech2, if not, it will OOM - if am_name != 'fastspeech2': - am_config.enable_memory_optim() - am_predictor = inference.create_predictor(am_config) - - voc_config = inference.Config( - str(Path(args.inference_dir) / (args.voc + ".pdmodel")), - str(Path(args.inference_dir) / (args.voc + ".pdiparams"))) - voc_config.enable_use_gpu(100, 0) - voc_config.enable_memory_optim() - voc_predictor = inference.create_predictor(voc_config) + # voc_predictor + voc_predictor = get_predictor(args, filed='voc') output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = [] - - print("in new inference") - - # construct dataset for evaluation - sentences = [] - with open(args.text, 'rt') as f: - for line in f: - items = line.strip().split() - utt_id = items[0] - if args.lang == 'zh': - sentence = "".join(items[1:]) - elif args.lang == 'en': - sentence = " ".join(items[1:]) - sentences.append((utt_id, sentence)) - get_tone_ids = False - get_spk_id = False - if am_name == 'speedyspeech': - get_tone_ids = True - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: - get_spk_id = True - spk_id = numpy.array([args.spk_id]) + sentences = get_sentences(args) - am_input_names = 
am_predictor.get_input_names() - print("am_input_names:", am_input_names) merge_sentences = True + fs = 24000 if am_dataset != 'ljspeech' else 22050 + # warmup + for utt_id, sentence in sentences[:3]: + with timer() as t: + am_output_data = get_am_output( + args, + am_predictor=am_predictor, + frontend=frontend, + merge_sentences=merge_sentences, + input=sentence) + wav = get_voc_output( + args, voc_predictor=voc_predictor, input=am_output_data) + speed = wav.size / t.elapse + rtf = fs / speed + print( + f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + + print("warm up done!") + + N = 0 + T = 0 for utt_id, sentence in sentences: - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - sentence, + with timer() as t: + am_output_data = get_am_output( + args, + am_predictor=am_predictor, + frontend=frontend, merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - elif args.lang == 'en': - input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in {'zh', 'en'}!") - - if get_tone_ids: - tone_ids = input_ids["tone_ids"] - tones = tone_ids[0].numpy() - tones_handle = am_predictor.get_input_handle(am_input_names[1]) - tones_handle.reshape(tones.shape) - tones_handle.copy_from_cpu(tones) - if get_spk_id: - spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) - spk_id_handle.reshape(spk_id.shape) - spk_id_handle.copy_from_cpu(spk_id) - phones = phone_ids[0].numpy() - phones_handle = am_predictor.get_input_handle(am_input_names[0]) - phones_handle.reshape(phones.shape) - phones_handle.copy_from_cpu(phones) - - am_predictor.run() - am_output_names = am_predictor.get_output_names() - am_output_handle = am_predictor.get_output_handle(am_output_names[0]) - am_output_data = am_output_handle.copy_to_cpu() - - voc_input_names = voc_predictor.get_input_names() - mel_handle = voc_predictor.get_input_handle(voc_input_names[0]) - mel_handle.reshape(am_output_data.shape) - mel_handle.copy_from_cpu(am_output_data) - - voc_predictor.run() - voc_output_names = voc_predictor.get_output_names() - voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0]) - wav = voc_output_handle.copy_to_cpu() + input=sentence) + wav = get_voc_output( + args, voc_predictor=voc_predictor, input=am_output_data) + + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = fs / speed sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000) + print( + f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") if __name__ == "__main__": diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py new file mode 100644 index 00000000000..c52cb372710 --- /dev/null +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -0,0 +1,243 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
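A note on the throughput figures printed above: `speed` is samples generated per wall-clock second, so `RTF = fs / speed` equals synthesis time divided by audio duration, and values below 1 mean faster than real time. A worked check with hypothetical numbers:

```python
# A worked check of the RTF arithmetic used in inference.py above.
fs = 24000           # output sample rate; 22050 for ljspeech models
elapse = 0.5         # seconds of wall-clock synthesis time
num_samples = 48000  # generated samples, i.e. 2.0 s of audio

speed = num_samples / elapse  # 96000 samples per second
rtf = fs / speed              # 0.25
assert rtf == elapse / (num_samples / fs)  # time spent / audio produced
```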
+# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import paddle +from paddle import jit +from paddle.static import InputSpec + +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.modules.normalizer import ZScore + +model_alias = { + # acoustic model + "speedyspeech": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", + "speedyspeech_inference": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + "tacotron2": + "paddlespeech.t2s.models.tacotron2:Tacotron2", + "tacotron2_inference": + "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", + # voc + "pwgan": + "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", + "pwgan_inference": + "paddlespeech.t2s.models.parallel_wavegan:PWGInference", + "mb_melgan": + "paddlespeech.t2s.models.melgan:MelGANGenerator", + "mb_melgan_inference": + "paddlespeech.t2s.models.melgan:MelGANInference", + "style_melgan": + "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", + "style_melgan_inference": + "paddlespeech.t2s.models.melgan:StyleMelGANInference", + "hifigan": + "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", + "hifigan_inference": + "paddlespeech.t2s.models.hifigan:HiFiGANInference", + "wavernn": + "paddlespeech.t2s.models.wavernn:WaveRNN", + "wavernn_inference": + "paddlespeech.t2s.models.wavernn:WaveRNNInference", +} + + +# input +def get_sentences(args): + # construct dataset for evaluation + sentences = [] + with open(args.text, 'rt') as f: + for line in f: + items = line.strip().split() + utt_id = items[0] + if 'lang' in args and args.lang == 'zh': + sentence = "".join(items[1:]) + elif 'lang' in args and args.lang == 'en': + sentence = " ".join(items[1:]) + sentences.append((utt_id, sentence)) + return sentences + + +def get_test_dataset(args, test_metadata, am_name, am_dataset): + if am_name == 'fastspeech2': + fields = ["utt_id", "text"] + if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + print("multiple speaker fastspeech2!") + fields += ["spk_id"] + elif 'voice_cloning' in args and args.voice_cloning: + print("voice cloning!") + fields += ["spk_emb"] + else: + print("single speaker fastspeech2!") + elif am_name == 'speedyspeech': + fields = ["utt_id", "phones", "tones"] + elif am_name == 'tacotron2': + fields = ["utt_id", "text"] + if 'voice_cloning' in args and args.voice_cloning: + print("voice cloning!") + fields += ["spk_emb"] + + test_dataset = DataTable(data=test_metadata, fields=fields) + return test_dataset + + +# frontend +def get_frontend(args): + if 'lang' in args and args.lang == 'zh': + frontend = Frontend( + phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) + elif 'lang' in args and args.lang == 'en': + frontend = English(phone_vocab_path=args.phones_dict) + else: + print("wrong lang!") + print("frontend done!") + return frontend + + +# dygraph +def get_am_inference(args, am_config): + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + tone_size = None + if 'tones_dict' in args and args.tones_dict: + with 
open(args.tones_dict, "r") as f: + tone_id = [line.strip().split() for line in f.readlines()] + tone_size = len(tone_id) + print("tone_size:", tone_size) + + spk_num = None + if 'speaker_dict' in args and args.speaker_dict: + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + print("spk_num:", spk_num) + + odim = am_config.n_mels + # model: {model_name}_{dataset} + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + + am_class = dynamic_import(am_name, model_alias) + am_inference_class = dynamic_import(am_name + '_inference', model_alias) + + if am_name == 'fastspeech2': + am = am_class( + idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"]) + elif am_name == 'speedyspeech': + am = am_class( + vocab_size=vocab_size, + tone_size=tone_size, + spk_num=spk_num, + **am_config["model"]) + elif am_name == 'tacotron2': + am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) + + am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) + am.eval() + am_mu, am_std = np.load(args.am_stat) + am_mu = paddle.to_tensor(am_mu) + am_std = paddle.to_tensor(am_std) + am_normalizer = ZScore(am_mu, am_std) + am_inference = am_inference_class(am_normalizer, am) + am_inference.eval() + print("acoustic model done!") + return am_inference, am_name, am_dataset + + +def get_voc_inference(args, voc_config): + # model: {model_name}_{dataset} + voc_name = args.voc[:args.voc.rindex('_')] + voc_class = dynamic_import(voc_name, model_alias) + voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) + if voc_name != 'wavernn': + voc = voc_class(**voc_config["generator_params"]) + voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"]) + voc.remove_weight_norm() + voc.eval() + else: + voc = voc_class(**voc_config["model"]) + voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"]) + voc.eval() + + voc_mu, voc_std = np.load(args.voc_stat) + voc_mu = paddle.to_tensor(voc_mu) + voc_std = paddle.to_tensor(voc_std) + voc_normalizer = ZScore(voc_mu, voc_std) + voc_inference = voc_inference_class(voc_normalizer, voc) + voc_inference.eval() + print("voc done!") + return voc_inference + + +# to static +def am_to_static(args, am_inference, am_name, am_dataset): + if am_name == 'fastspeech2': + if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + am_inference = jit.to_static( + am_inference, + input_spec=[ + InputSpec([-1], dtype=paddle.int64), + InputSpec([1], dtype=paddle.int64), + ], ) + else: + am_inference = jit.to_static( + am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) + + elif am_name == 'speedyspeech': + if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + am_inference = jit.to_static( + am_inference, + input_spec=[ + InputSpec([-1], dtype=paddle.int64), # text + InputSpec([-1], dtype=paddle.int64), # tone + InputSpec([1], dtype=paddle.int64), # spk_id + None # duration + ]) + else: + am_inference = jit.to_static( + am_inference, + input_spec=[ + InputSpec([-1], dtype=paddle.int64), + InputSpec([-1], dtype=paddle.int64) + ]) + + elif am_name == 'tacotron2': + am_inference = jit.to_static( + am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) + + paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am)) + am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am)) + return am_inference + + +def voc_to_static(args, voc_inference): + voc_inference = jit.to_static( + voc_inference, 
input_spec=[ + InputSpec([-1, 80], dtype=paddle.float32), + ]) + paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc)) + voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc)) + return voc_inference diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index 81da14f2eae..abb1eb4eb6e 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -23,48 +23,11 @@ from timer import timer from yacs.config import CfgNode -from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.t2s.datasets.data_table import DataTable -from paddlespeech.t2s.modules.normalizer import ZScore +from paddlespeech.t2s.exps.syn_utils import get_am_inference +from paddlespeech.t2s.exps.syn_utils import get_test_dataset +from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.utils import str2bool -model_alias = { - # acoustic model - "speedyspeech": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", - "speedyspeech_inference": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", - "fastspeech2": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2", - "fastspeech2_inference": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", - "tacotron2": - "paddlespeech.t2s.models.tacotron2:Tacotron2", - "tacotron2_inference": - "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", - # voc - "pwgan": - "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", - "pwgan_inference": - "paddlespeech.t2s.models.parallel_wavegan:PWGInference", - "mb_melgan": - "paddlespeech.t2s.models.melgan:MelGANGenerator", - "mb_melgan_inference": - "paddlespeech.t2s.models.melgan:MelGANInference", - "style_melgan": - "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", - "style_melgan_inference": - "paddlespeech.t2s.models.melgan:StyleMelGANInference", - "hifigan": - "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", - "hifigan_inference": - "paddlespeech.t2s.models.hifigan:HiFiGANInference", - "wavernn": - "paddlespeech.t2s.models.wavernn:WaveRNN", - "wavernn_inference": - "paddlespeech.t2s.models.wavernn:WaveRNNInference", -} - def evaluate(args): # dataloader has been too verbose @@ -86,96 +49,12 @@ def evaluate(args): print(am_config) print(voc_config) - # construct dataset for evaluation - - # model: {model_name}_{dataset} - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] - - if am_name == 'fastspeech2': - fields = ["utt_id", "text"] - spk_num = None - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: - print("multiple speaker fastspeech2!") - with open(args.speaker_dict, 'rt') as f: - spk_id = [line.strip().split() for line in f.readlines()] - spk_num = len(spk_id) - fields += ["spk_id"] - elif args.voice_cloning: - print("voice cloning!") - fields += ["spk_emb"] - else: - print("single speaker fastspeech2!") - print("spk_num:", spk_num) - elif am_name == 'speedyspeech': - fields = ["utt_id", "phones", "tones"] - elif am_name == 'tacotron2': - fields = ["utt_id", "text"] - if args.voice_cloning: - print("voice cloning!") - fields += ["spk_emb"] - - test_dataset = DataTable(data=test_metadata, fields=fields) - - with open(args.phones_dict, "r") as f: - phn_id = [line.strip().split() for line in f.readlines()] - vocab_size = len(phn_id) - print("vocab_size:", vocab_size) - - tone_size = None - if args.tones_dict: - with open(args.tones_dict, "r") as f: - tone_id = [line.strip().split() for line in 
f.readlines()] - tone_size = len(tone_id) - print("tone_size:", tone_size) - # acoustic model - odim = am_config.n_mels - am_class = dynamic_import(am_name, model_alias) - am_inference_class = dynamic_import(am_name + '_inference', model_alias) - - if am_name == 'fastspeech2': - am = am_class( - idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"]) - elif am_name == 'speedyspeech': - am = am_class( - vocab_size=vocab_size, tone_size=tone_size, **am_config["model"]) - elif am_name == 'tacotron2': - am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) - - am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) - am.eval() - am_mu, am_std = np.load(args.am_stat) - am_mu = paddle.to_tensor(am_mu) - am_std = paddle.to_tensor(am_std) - am_normalizer = ZScore(am_mu, am_std) - am_inference = am_inference_class(am_normalizer, am) - print("am_inference.training0:", am_inference.training) - am_inference.eval() - print("acoustic model done!") + am_inference, am_name, am_dataset = get_am_inference(args, am_config) + test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) # vocoder - # model: {model_name}_{dataset} - voc_name = args.voc[:args.voc.rindex('_')] - voc_class = dynamic_import(voc_name, model_alias) - voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) - if voc_name != 'wavernn': - voc = voc_class(**voc_config["generator_params"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"]) - voc.remove_weight_norm() - voc.eval() - else: - voc = voc_class(**voc_config["model"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"]) - voc.eval() - voc_mu, voc_std = np.load(args.voc_stat) - voc_mu = paddle.to_tensor(voc_mu) - voc_std = paddle.to_tensor(voc_std) - voc_normalizer = ZScore(voc_mu, voc_std) - voc_inference = voc_inference_class(voc_normalizer, voc) - print("voc_inference.training0:", voc_inference.training) - voc_inference.eval() - print("voc done!") + voc_inference = get_voc_inference(args, voc_config) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -227,7 +106,7 @@ def evaluate(args): print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }") -def main(): +def parse_args(): # parse args and config and redirect to train_sp parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") @@ -264,7 +143,6 @@ def main(): "--tones_dict", type=str, default=None, help="tone vocabulary file.") parser.add_argument( "--speaker_dict", type=str, default=None, help="speaker id map file.") - parser.add_argument( "--voice-cloning", type=str2bool, @@ -278,10 +156,10 @@ def main(): choices=[ 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc', + 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk', 'style_melgan_csmsc' ], help='Choose vocoder type of tts task.') - parser.add_argument( '--voc_config', type=str, @@ -302,7 +180,12 @@ def main(): parser.add_argument("--output_dir", type=str, help="output dir.") args = parser.parse_args() + return args + + +def main(): + args = parse_args() if args.ngpu == 0: paddle.set_device("cpu") elif args.ngpu > 0: diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 94180f8531a..10b33c60acf 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -12,59 +12,20 @@ # See the License for the specific language governing permissions and # limitations 
under the License. import argparse -import os from pathlib import Path -import numpy as np import paddle import soundfile as sf import yaml -from paddle import jit -from paddle.static import InputSpec from timer import timer from yacs.config import CfgNode -from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.t2s.frontend import English -from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.t2s.modules.normalizer import ZScore - -model_alias = { - # acoustic model - "speedyspeech": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", - "speedyspeech_inference": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", - "fastspeech2": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2", - "fastspeech2_inference": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", - "tacotron2": - "paddlespeech.t2s.models.tacotron2:Tacotron2", - "tacotron2_inference": - "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", - # voc - "pwgan": - "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", - "pwgan_inference": - "paddlespeech.t2s.models.parallel_wavegan:PWGInference", - "mb_melgan": - "paddlespeech.t2s.models.melgan:MelGANGenerator", - "mb_melgan_inference": - "paddlespeech.t2s.models.melgan:MelGANInference", - "style_melgan": - "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", - "style_melgan_inference": - "paddlespeech.t2s.models.melgan:StyleMelGANInference", - "hifigan": - "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", - "hifigan_inference": - "paddlespeech.t2s.models.hifigan:HiFiGANInference", - "wavernn": - "paddlespeech.t2s.models.wavernn:WaveRNN", - "wavernn_inference": - "paddlespeech.t2s.models.wavernn:WaveRNNInference", -} +from paddlespeech.t2s.exps.syn_utils import am_to_static +from paddlespeech.t2s.exps.syn_utils import get_am_inference +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_voc_inference +from paddlespeech.t2s.exps.syn_utils import voc_to_static def evaluate(args): @@ -81,151 +42,24 @@ def evaluate(args): print(am_config) print(voc_config) - # construct dataset for evaluation - sentences = [] - with open(args.text, 'rt') as f: - for line in f: - items = line.strip().split() - utt_id = items[0] - if args.lang == 'zh': - sentence = "".join(items[1:]) - elif args.lang == 'en': - sentence = " ".join(items[1:]) - sentences.append((utt_id, sentence)) - - with open(args.phones_dict, "r") as f: - phn_id = [line.strip().split() for line in f.readlines()] - vocab_size = len(phn_id) - print("vocab_size:", vocab_size) - - tone_size = None - if args.tones_dict: - with open(args.tones_dict, "r") as f: - tone_id = [line.strip().split() for line in f.readlines()] - tone_size = len(tone_id) - print("tone_size:", tone_size) - - spk_num = None - if args.speaker_dict: - with open(args.speaker_dict, 'rt') as f: - spk_id = [line.strip().split() for line in f.readlines()] - spk_num = len(spk_id) - print("spk_num:", spk_num) + sentences = get_sentences(args) # frontend - if args.lang == 'zh': - frontend = Frontend( - phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) - elif args.lang == 'en': - frontend = English(phone_vocab_path=args.phones_dict) - print("frontend done!") + frontend = get_frontend(args) # acoustic model - odim = am_config.n_mels - # model: {model_name}_{dataset} - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] - - 
am_class = dynamic_import(am_name, model_alias) - am_inference_class = dynamic_import(am_name + '_inference', model_alias) - - if am_name == 'fastspeech2': - am = am_class( - idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"]) - elif am_name == 'speedyspeech': - am = am_class( - vocab_size=vocab_size, - tone_size=tone_size, - spk_num=spk_num, - **am_config["model"]) - elif am_name == 'tacotron2': - am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) - - am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) - am.eval() - am_mu, am_std = np.load(args.am_stat) - am_mu = paddle.to_tensor(am_mu) - am_std = paddle.to_tensor(am_std) - am_normalizer = ZScore(am_mu, am_std) - am_inference = am_inference_class(am_normalizer, am) - am_inference.eval() - print("acoustic model done!") + am_inference, am_name, am_dataset = get_am_inference(args, am_config) # vocoder - # model: {model_name}_{dataset} - voc_name = args.voc[:args.voc.rindex('_')] - voc_class = dynamic_import(voc_name, model_alias) - voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) - if voc_name != 'wavernn': - voc = voc_class(**voc_config["generator_params"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"]) - voc.remove_weight_norm() - voc.eval() - else: - voc = voc_class(**voc_config["model"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"]) - voc.eval() - - voc_mu, voc_std = np.load(args.voc_stat) - voc_mu = paddle.to_tensor(voc_mu) - voc_std = paddle.to_tensor(voc_std) - voc_normalizer = ZScore(voc_mu, voc_std) - voc_inference = voc_inference_class(voc_normalizer, voc) - voc_inference.eval() - print("voc done!") + voc_inference = get_voc_inference(args, voc_config) # whether dygraph to static if args.inference_dir: # acoustic model - if am_name == 'fastspeech2': - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: - am_inference = jit.to_static( - am_inference, - input_spec=[ - InputSpec([-1], dtype=paddle.int64), - InputSpec([1], dtype=paddle.int64) - ]) - else: - am_inference = jit.to_static( - am_inference, - input_spec=[InputSpec([-1], dtype=paddle.int64)]) - - elif am_name == 'speedyspeech': - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: - am_inference = jit.to_static( - am_inference, - input_spec=[ - InputSpec([-1], dtype=paddle.int64), # text - InputSpec([-1], dtype=paddle.int64), # tone - InputSpec([1], dtype=paddle.int64), # spk_id - None # duration - ]) - else: - am_inference = jit.to_static( - am_inference, - input_spec=[ - InputSpec([-1], dtype=paddle.int64), - InputSpec([-1], dtype=paddle.int64) - ]) - - elif am_name == 'tacotron2': - am_inference = jit.to_static( - am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) - - paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am)) - am_inference = paddle.jit.load( - os.path.join(args.inference_dir, args.am)) + am_inference = am_to_static(args, am_inference, am_name, am_dataset) # vocoder - voc_inference = jit.to_static( - voc_inference, - input_spec=[ - InputSpec([-1, 80], dtype=paddle.float32), - ]) - paddle.jit.save(voc_inference, - os.path.join(args.inference_dir, args.voc)) - voc_inference = paddle.jit.load( - os.path.join(args.inference_dir, args.voc)) + voc_inference = voc_to_static(args, voc_inference) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -298,7 +132,7 @@ def evaluate(args): print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }") -def main(): +def parse_args(): # 
parse args and config and redirect to train_sp parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") @@ -346,12 +180,19 @@ def main(): type=str, default='pwgan_csmsc', choices=[ - 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', - 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc', - 'wavernn_csmsc' + 'pwgan_csmsc', + 'pwgan_ljspeech', + 'pwgan_aishell3', + 'pwgan_vctk', + 'mb_melgan_csmsc', + 'style_melgan_csmsc', + 'hifigan_csmsc', + 'hifigan_ljspeech', + 'hifigan_aishell3', + 'hifigan_vctk', + 'wavernn_csmsc', ], help='Choose vocoder type of tts task.') - parser.add_argument( '--voc_config', type=str, @@ -386,6 +227,11 @@ def main(): parser.add_argument("--output_dir", type=str, help="output dir.") args = parser.parse_args() + return args + + +def main(): + args = parse_args() if args.ngpu == 0: paddle.set_device("cpu") diff --git a/paddlespeech/t2s/exps/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning.py index 3de30774f5b..1afd21dfffb 100644 --- a/paddlespeech/t2s/exps/voice_cloning.py +++ b/paddlespeech/t2s/exps/voice_cloning.py @@ -21,29 +21,12 @@ import yaml from yacs.config import CfgNode -from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.exps.syn_utils import get_am_inference +from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder -model_alias = { - # acoustic model - "fastspeech2": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2", - "fastspeech2_inference": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", - "tacotron2": - "paddlespeech.t2s.models.tacotron2:Tacotron2", - "tacotron2_inference": - "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", - # voc - "pwgan": - "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", - "pwgan_inference": - "paddlespeech.t2s.models.parallel_wavegan:PWGInference", -} - def voice_cloning(args): # Init body. 
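The `--inference_dir` branch above now delegates all dygraph-to-static export to the `syn_utils` helpers. A hedged standalone sketch of the same flow; every path and config file name is hypothetical, and `argparse.Namespace` merely mimics the parsed CLI arguments:

```python
# A minimal sketch, assuming checkpoints/stats/dicts exported by the usual
# training recipes.
import yaml
from argparse import Namespace
from yacs.config import CfgNode

from paddlespeech.t2s.exps.syn_utils import (am_to_static, get_am_inference,
                                             get_voc_inference, voc_to_static)

args = Namespace(
    am='fastspeech2_csmsc', am_ckpt='fastspeech2.pdz',
    am_stat='speech_stats.npy', phones_dict='phone_id_map.txt',
    tones_dict=None, speaker_dict=None,
    voc='hifigan_csmsc', voc_ckpt='hifigan.pdz', voc_stat='feats_stats.npy',
    inference_dir='inference')

with open('fastspeech2.yaml') as f:
    am_config = CfgNode(yaml.safe_load(f))
with open('hifigan.yaml') as f:
    voc_config = CfgNode(yaml.safe_load(f))

am_inference, am_name, am_dataset = get_am_inference(args, am_config)
voc_inference = get_voc_inference(args, voc_config)
# jit.save the graphs into args.inference_dir and reload them
am_inference = am_to_static(args, am_inference, am_name, am_dataset)
voc_inference = voc_to_static(args, voc_inference)
```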
@@ -79,55 +62,14 @@ def voice_cloning(args): speaker_encoder.eval() print("GE2E Done!") - with open(args.phones_dict, "r") as f: - phn_id = [line.strip().split() for line in f.readlines()] - vocab_size = len(phn_id) - print("vocab_size:", vocab_size) + frontend = Frontend(phone_vocab_path=args.phones_dict) + print("frontend done!") # acoustic model - odim = am_config.n_mels - # model: {model_name}_{dataset} - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] - - am_class = dynamic_import(am_name, model_alias) - am_inference_class = dynamic_import(am_name + '_inference', model_alias) - - if am_name == 'fastspeech2': - am = am_class( - idim=vocab_size, odim=odim, spk_num=None, **am_config["model"]) - elif am_name == 'tacotron2': - am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) - - am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) - am.eval() - am_mu, am_std = np.load(args.am_stat) - am_mu = paddle.to_tensor(am_mu) - am_std = paddle.to_tensor(am_std) - am_normalizer = ZScore(am_mu, am_std) - am_inference = am_inference_class(am_normalizer, am) - am_inference.eval() - print("acoustic model done!") + am_inference, *_ = get_am_inference(args, am_config) # vocoder - # model: {model_name}_{dataset} - voc_name = args.voc[:args.voc.rindex('_')] - voc_class = dynamic_import(voc_name, model_alias) - voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) - voc = voc_class(**voc_config["generator_params"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"]) - voc.remove_weight_norm() - voc.eval() - voc_mu, voc_std = np.load(args.voc_stat) - voc_mu = paddle.to_tensor(voc_mu) - voc_std = paddle.to_tensor(voc_std) - voc_normalizer = ZScore(voc_mu, voc_std) - voc_inference = voc_inference_class(voc_normalizer, voc) - voc_inference.eval() - print("voc done!") - - frontend = Frontend(phone_vocab_path=args.phones_dict) - print("frontend done!") + voc_inference = get_voc_inference(args, voc_config) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -170,7 +112,7 @@ def voice_cloning(args): print(f"{utt_id} done!") -def main(): +def parse_args(): # parse args and config and redirect to train_sp parser = argparse.ArgumentParser(description="") parser.add_argument( @@ -240,6 +182,11 @@ def main(): parser.add_argument("--output-dir", type=str, help="output dir.") args = parser.parse_args() + return args + + +def main(): + args = parse_args() if args.ngpu == 0: paddle.set_device("cpu") diff --git a/paddlespeech/t2s/modules/predictor/length_regulator.py b/paddlespeech/t2s/modules/predictor/length_regulator.py index 62d707d2234..2472c413bea 100644 --- a/paddlespeech/t2s/modules/predictor/length_regulator.py +++ b/paddlespeech/t2s/modules/predictor/length_regulator.py @@ -101,6 +101,16 @@ def forward(self, xs, ds, alpha=1.0, is_inference=False): assert alpha > 0 ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha) ds = ds.cast(dtype=paddle.int64) + ''' + from distutils.version import LooseVersion + from paddlespeech.t2s.modules.nets_utils import pad_list + # dygraph-to-static conversion of this branch does not work on paddle 2.2.2 + # if LooseVersion(paddle.__version__) >= "2.3.0" or hasattr(paddle, 'repeat_interleave'): + # if LooseVersion(paddle.__version__) >= "2.3.0": + if hasattr(paddle, 'repeat_interleave'): + repeat = [paddle.repeat_interleave(x, d, axis=0) for x, d in zip(xs, ds)] + return pad_list(repeat, self.pad_value) + ''' if is_inference: return self.expand(xs, ds) else: diff --git
a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py new file mode 100644 index 00000000000..6432acb8169 --- /dev/null +++ b/paddlespeech/vector/cluster/diarization.py @@ -0,0 +1,1082 @@ +# Copyright (c) 2022 SpeechBrain Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script contains basic functions used for speaker diarization. +This script has an optional dependency on the open-source sklearn library. +A few sklearn functions are modified in this script as per requirement. +""" + +import argparse +import numbers +import warnings +import scipy +import numpy as np +from distutils.util import strtobool + +from scipy import sparse +from scipy.sparse.linalg import eigsh +from scipy.sparse.csgraph import connected_components +from scipy.sparse.csgraph import laplacian as csgraph_laplacian + +import sklearn +from sklearn.neighbors import kneighbors_graph +from sklearn.cluster import SpectralClustering +from sklearn.cluster._kmeans import k_means + + +def _graph_connected_component(graph, node_id): + """ + Find the largest connected component of the graph that contains the + given node. + + Arguments + --------- + graph : array-like, shape: (n_samples, n_samples) + Adjacency matrix of the graph, non-zero weight means an edge + between the nodes. + node_id : int + The index of the query node of the graph. + + Returns + ------- + connected_components_matrix : array-like + shape - (n_samples,). + An array of bool values indicating the indexes of the nodes belonging + to the largest connected component containing the query node. + """ + + n_node = graph.shape[0] + if sparse.issparse(graph): + # speed up row-wise access to boolean connection mask + graph = graph.tocsr() + connected_nodes = np.zeros(n_node, dtype=bool) + nodes_to_explore = np.zeros(n_node, dtype=bool) + nodes_to_explore[node_id] = True + for _ in range(n_node): + last_num_component = connected_nodes.sum() + np.logical_or(connected_nodes, nodes_to_explore, out=connected_nodes) + if last_num_component >= connected_nodes.sum(): + break + indices = np.where(nodes_to_explore)[0] + nodes_to_explore.fill(False) + for i in indices: + if sparse.issparse(graph): + neighbors = graph[i].toarray().ravel() + else: + neighbors = graph[i] + np.logical_or(nodes_to_explore, neighbors, out=nodes_to_explore) + return connected_nodes + + +def _graph_is_connected(graph): + """ + Return whether the graph is connected (True) or not (False). + + Arguments + --------- + graph : array-like or sparse matrix, shape: (n_samples, n_samples) + Adjacency matrix of the graph, non-zero weight means an edge between the nodes. + + Returns + ------- + is_connected : bool + True means the graph is fully connected and False means not.
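+ + Example + ------- + >>> import numpy as np + >>> # illustrative doctest added for clarity: node 2 has no edge to + >>> # nodes 0 and 1, so the graph is not fully connected + >>> graph = np.array([[1, 1, 0], + ... [1, 1, 0], + ... [0, 0, 1]]) + >>> bool(_graph_is_connected(graph)) + False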
+ """ + + if sparse.isspmatrix(graph): + # sparse graph, find all the connected components + n_connected_components, _ = connected_components(graph) + return n_connected_components == 1 + else: + # dense graph, find all connected components start from node 0 + return _graph_connected_component(graph, 0).sum() == graph.shape[0] + + +def _set_diag(laplacian, value, norm_laplacian): + """ + Set the diagonal of the laplacian matrix and convert it to a sparse + format well suited for eigenvalue decomposition. + + Arguments + --------- + laplacian : array or sparse matrix + The graph laplacian. + value : float + The value of the diagonal. + norm_laplacian : bool + Whether the value of the diagonal should be changed or not. + + Returns + ------- + laplacian : array or sparse matrix + An array of matrix in a form that is well suited to fast eigenvalue + decomposition, depending on the bandwidth of the matrix. + """ + + n_nodes = laplacian.shape[0] + # We need all entries in the diagonal to values + if not sparse.isspmatrix(laplacian): + if norm_laplacian: + laplacian.flat[::n_nodes + 1] = value + else: + laplacian = laplacian.tocoo() + if norm_laplacian: + diag_idx = laplacian.row == laplacian.col + laplacian.data[diag_idx] = value + # If the matrix has a small number of diagonals (as in the + # case of structured matrices coming from images), the + # dia format might be best suited for matvec products: + n_diags = np.unique(laplacian.row - laplacian.col).size + if n_diags <= 7: + # 3 or less outer diagonals on each side + laplacian = laplacian.todia() + else: + # csr has the fastest matvec and is thus best suited to + # arpack + laplacian = laplacian.tocsr() + return laplacian + + +def _deterministic_vector_sign_flip(u): + """ + Modify the sign of vectors for reproducibility. Flips the sign of + elements of all the vectors (rows of u) such that the absolute + maximum element of each vector is positive. + + Arguments + --------- + u : ndarray + Array with vectors as its rows. + + Returns + ------- + u_flipped : ndarray + Array with the sign flipped vectors as its rows. The same shape as `u`. + """ + + max_abs_rows = np.argmax(np.abs(u), axis=1) + signs = np.sign(u[range(u.shape[0]), max_abs_rows]) + u *= signs[:, np.newaxis] + return u + + +def _check_random_state(seed): + """ + Turn seed into a np.random.RandomState instance. + + Arguments + --------- + seed : None | int | instance of RandomState + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. + Otherwise raise ValueError. + """ + + if seed is None or seed is np.random: + return np.random.mtrand._rand + if isinstance(seed, numbers.Integral): + return np.random.RandomState(seed) + if isinstance(seed, np.random.RandomState): + return seed + raise ValueError("%r cannot be used to seed a np.random.RandomState" + " instance" % seed) + + +def spectral_embedding( + adjacency, + n_components=8, + norm_laplacian=True, + drop_first=True, ): + """ + Returns spectral embeddings. + + Arguments + --------- + adjacency : array-like or sparse graph + shape - (n_samples, n_samples) + The adjacency matrix of the graph to embed. + n_components : int + The dimension of the projection subspace. + norm_laplacian : bool + If True, then compute normalized Laplacian. + drop_first : bool + Whether to drop the first eigenvector. + + Returns + ------- + embedding : array + Spectral embeddings for each sample. 
+ + Example + ------- + >>> import numpy as np + >>> import diarization as diar + >>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + >>> embs = diar.spectral_embedding(affinity, 3) + >>> # Notice similar embeddings + >>> print(np.around(embs, decimals=3)) + [[ 0.075 0.244 0.285] + [ 0.083 0.356 -0.203] + [ 0.083 0.356 -0.203] + [ 0.26 -0.149 0.154] + [ 0.29 -0.218 -0.11 ] + [ 0.29 -0.218 -0.11 ] + [-0.198 -0.084 -0.122] + [-0.198 -0.084 -0.122] + [-0.198 -0.084 -0.122] + [-0.167 -0.044 0.316]] + """ + + # Whether to drop the first eigenvector + if drop_first: + n_components = n_components + 1 + + if not _graph_is_connected(adjacency): + warnings.warn("Graph is not fully connected, spectral embedding" + " may not work as expected.") + + laplacian, dd = csgraph_laplacian( + adjacency, normed=norm_laplacian, return_diag=True) + + laplacian = _set_diag(laplacian, 1, norm_laplacian) + + laplacian *= -1 + + vals, diffusion_map = eigsh( + laplacian, + k=n_components, + sigma=1.0, + which="LM", ) + + embedding = diffusion_map.T[n_components::-1] + + if norm_laplacian: + embedding = embedding / dd + + embedding = _deterministic_vector_sign_flip(embedding) + if drop_first: + return embedding[1:n_components].T + else: + return embedding[:n_components].T + + +def spectral_clustering( + affinity, + n_clusters=8, + n_components=None, + random_state=None, + n_init=10, ): + """ + Performs spectral clustering. + + Arguments + --------- + affinity : matrix + Affinity matrix. + n_clusters : int + Number of clusters for kmeans. + n_components : int + Number of components to retain while estimating spectral embeddings. + random_state : int + A pseudo random number generator used by kmeans. + n_init : int + Number of times the k-means algorithm will be run with different centroid seeds. + + Returns + ------- + labels : array + Cluster label for each sample. + + Example + ------- + >>> import numpy as np + >>> import diarization as diar + >>> affinity = np.array([[1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + ... [0.5, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0, 0, 0, 0, 0, 0, 1, 1, 1, 1], + ... [0.5, 0, 0, 0, 0, 0, 1, 1, 1, 1]]) + >>> labs = diar.spectral_clustering(affinity, 3) + >>> # print (labs) # [2 2 2 1 1 1 0 0 0 0] + """ + + random_state = _check_random_state(random_state) + n_components = n_clusters if n_components is None else n_components + + maps = spectral_embedding( + affinity, + n_components=n_components, + drop_first=False, ) + + _, labels, _ = k_means( + maps, n_clusters, random_state=random_state, n_init=n_init) + + return labels + + +class EmbeddingMeta: + """ + A utility class to pack deep embeddings and meta-information in one object. + + Arguments + --------- + segset : list + List of session IDs as an array of strings. + stats : ndarray + An ndarray of float64. Each line contains the embedding + from the corresponding session.
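+ + Example + ------- + >>> import numpy as np + >>> # illustrative doctest added for clarity; the segment IDs and + >>> # embedding values below are made up + >>> segs = np.array(["utt1_0.0_1.5", "utt1_1.5_3.0"], dtype="|O") + >>> embs = np.array([[3.0, 4.0], [6.0, 8.0]], dtype=np.float64) + >>> meta = EmbeddingMeta(segs, embs) + >>> meta.norm_stats() # each row divided by its Euclidean norm (5 and 10) + >>> meta.stats + array([[0.6, 0.8], + [0.6, 0.8]])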
+ """ + + def __init__( + self, + segset=None, + stats=None, ): + + if segset is None: + self.segset = numpy.empty(0, dtype="|O") + self.stats = numpy.array([], dtype=np.float64) + else: + self.segset = segset + self.stats = stats + + def norm_stats(self): + """ + Divide all first-order statistics by their Euclidean norm. + """ + + vect_norm = np.clip(np.linalg.norm(self.stats, axis=1), 1e-08, np.inf) + self.stats = (self.stats.transpose() / vect_norm).transpose() + + +class SpecClustUnorm: + """ + This class implements the spectral clustering with unnormalized affinity matrix. + Useful when affinity matrix is based on cosine similarities. + + Reference + --------- + Von Luxburg, U. A tutorial on spectral clustering. Stat Comput 17, 395–416 (2007). + https://doi.org/10.1007/s11222-007-9033-z + + Example + ------- + >>> import diarization as diar + >>> clust = diar.SpecClustUnorm(min_num_spkrs=2, max_num_spkrs=10) + >>> emb = [[ 2.1, 3.1, 4.1, 4.2, 3.1], + ... [ 2.2, 3.1, 4.2, 4.2, 3.2], + ... [ 2.0, 3.0, 4.0, 4.1, 3.0], + ... [ 8.0, 7.0, 7.0, 8.1, 9.0], + ... [ 8.1, 7.1, 7.2, 8.1, 9.2], + ... [ 8.3, 7.4, 7.0, 8.4, 9.0], + ... [ 0.3, 0.4, 0.4, 0.5, 0.8], + ... [ 0.4, 0.3, 0.6, 0.7, 0.8], + ... [ 0.2, 0.3, 0.2, 0.3, 0.7], + ... [ 0.3, 0.4, 0.4, 0.4, 0.7],] + >>> # Estimating similarity matrix + >>> sim_mat = clust.get_sim_mat(emb) + >>> print (np.around(sim_mat[5:,5:], decimals=3)) + [[1. 0.957 0.961 0.904 0.966] + [0.957 1. 0.977 0.982 0.997] + [0.961 0.977 1. 0.928 0.972] + [0.904 0.982 0.928 1. 0.976] + [0.966 0.997 0.972 0.976 1. ]] + >>> # Prunning + >>> prunned_sim_mat = clust.p_pruning(sim_mat, 0.3) + >>> print (np.around(prunned_sim_mat[5:,5:], decimals=3)) + [[1. 0. 0. 0. 0. ] + [0. 1. 0. 0.982 0.997] + [0. 0.977 1. 0. 0.972] + [0. 0.982 0. 1. 0.976] + [0. 0.997 0. 0.976 1. ]] + >>> # Symmetrization + >>> sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T) + >>> print (np.around(sym_prund_sim_mat[5:,5:], decimals=3)) + [[1. 0. 0. 0. 0. ] + [0. 1. 0.489 0.982 0.997] + [0. 0.489 1. 0. 0.486] + [0. 0.982 0. 1. 0.976] + [0. 0.997 0.486 0.976 1. ]] + >>> # Laplacian + >>> laplacian = clust.get_laplacian(sym_prund_sim_mat) + >>> print (np.around(laplacian[5:,5:], decimals=3)) + [[ 1.999 0. 0. 0. 0. ] + [ 0. 2.468 -0.489 -0.982 -0.997] + [ 0. -0.489 0.975 0. -0.486] + [ 0. -0.982 0. 1.958 -0.976] + [ 0. -0.997 -0.486 -0.976 2.458]] + >>> # Spectral Embeddings + >>> spec_emb, num_of_spk = clust.get_spec_embs(laplacian, 3) + >>> print(num_of_spk) + 3 + >>> # Clustering + >>> clust.cluster_embs(spec_emb, num_of_spk) + >>> # print (clust.labels_) # [0 0 0 2 2 2 1 1 1 1] + >>> # Complete spectral clustering + >>> clust.do_spec_clust(emb, k_oracle=3, p_val=0.3) + >>> # print(clust.labels_) # [0 0 0 2 2 2 1 1 1 1] + """ + + def __init__(self, min_num_spkrs=2, max_num_spkrs=10): + + self.min_num_spkrs = min_num_spkrs + self.max_num_spkrs = max_num_spkrs + + def do_spec_clust(self, X, k_oracle, p_val): + """ + Function for spectral clustering. + + Arguments + --------- + X : array + (n_samples, n_features). + Embeddings extracted from the model. + k_oracle : int + Number of speakers (when oracle number of speakers). + p_val : float + p percent value to prune the affinity matrix. 
+ """ + + # Similarity matrix computation + sim_mat = self.get_sim_mat(X) + + # Refining similarity matrix with p_val + prunned_sim_mat = self.p_pruning(sim_mat, p_val) + + # Symmetrization + sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T) + + # Laplacian calculation + laplacian = self.get_laplacian(sym_prund_sim_mat) + + # Get Spectral Embeddings + emb, num_of_spk = self.get_spec_embs(laplacian, k_oracle) + + # Perform clustering + self.cluster_embs(emb, num_of_spk) + + def get_sim_mat(self, X): + """ + Returns the similarity matrix based on cosine similarities. + + Arguments + --------- + X : array + (n_samples, n_features). + Embeddings extracted from the model. + + Returns + ------- + M : array + (n_samples, n_samples). + Similarity matrix with cosine similarities between each pair of embedding. + """ + + # Cosine similarities + M = sklearn.metrics.pairwise.cosine_similarity(X, X) + return M + + def p_pruning(self, A, pval): + """ + Refine the affinity matrix by zeroing less similar values. + + Arguments + --------- + A : array + (n_samples, n_samples). + Affinity matrix. + pval : float + p-value to be retained in each row of the affinity matrix. + + Returns + ------- + A : array + (n_samples, n_samples). + Prunned affinity matrix based on p_val. + """ + + n_elems = int((1 - pval) * A.shape[0]) + + # For each row in a affinity matrix + for i in range(A.shape[0]): + low_indexes = np.argsort(A[i, :]) + low_indexes = low_indexes[0:n_elems] + + # Replace smaller similarity values by 0s + A[i, low_indexes] = 0 + + return A + + def get_laplacian(self, M): + """ + Returns the un-normalized laplacian for the given affinity matrix. + + Arguments + --------- + M : array + (n_samples, n_samples) + Affinity matrix. + + Returns + ------- + L : array + (n_samples, n_samples) + Laplacian matrix. + """ + + M[np.diag_indices(M.shape[0])] = 0 + D = np.sum(np.abs(M), axis=1) + D = np.diag(D) + L = D - M + return L + + def get_spec_embs(self, L, k_oracle=4): + """ + Returns spectral embeddings and estimates the number of speakers + using maximum Eigen gap. + + Arguments + --------- + L : array (n_samples, n_samples) + Laplacian matrix. + k_oracle : int + Number of speakers when the condition is oracle number of speakers, + else None. + + Returns + ------- + emb : array (n_samples, n_components) + Spectral embedding for each sample with n Eigen components. + num_of_spk : int + Estimated number of speakers. If the condition is set to the oracle + number of speakers then returns k_oracle. + """ + + lambdas, eig_vecs = scipy.linalg.eigh(L) + + # if params["oracle_n_spkrs"] is True: + if k_oracle is not None: + num_of_spk = k_oracle + else: + lambda_gap_list = self.get_eigen_gaps(lambdas[1:self.max_num_spkrs]) + + num_of_spk = (np.argmax( + lambda_gap_list[:min(self.max_num_spkrs, len(lambda_gap_list))]) + + 2) + + if num_of_spk < self.min_num_spkrs: + num_of_spk = self.min_num_spkrs + + emb = eig_vecs[:, 0:num_of_spk] + + return emb, num_of_spk + + def cluster_embs(self, emb, k): + """ + Clusters the embeddings using kmeans. + + Arguments + --------- + emb : array (n_samples, n_components) + Spectral embedding for each sample with n Eigen components. + k : int + Number of clusters to kmeans. + + Returns + ------- + self.labels_ : self + Labels for each sample embedding. + """ + _, self.labels_, _ = k_means(emb, k) + + def get_eigen_gaps(self, eig_vals): + """ + Returns the difference (gaps) between the Eigen values. 
+ + Arguments + --------- + eig_vals : list + List of Eigen values. + + Returns + ------- + eig_vals_gap_list : list + List of differences (gaps) between adjacent Eigen values. + """ + + eig_vals_gap_list = [] + for i in range(len(eig_vals) - 1): + gap = float(eig_vals[i + 1]) - float(eig_vals[i]) + eig_vals_gap_list.append(gap) + + return eig_vals_gap_list + + +class SpecCluster(SpectralClustering): + def perform_sc(self, X, n_neighbors=10): + """ + Performs spectral clustering using sklearn on embeddings. + + Arguments + --------- + X : array (n_samples, n_features) + Embeddings to be clustered. + n_neighbors : int + Number of neighbors used in estimating the affinity matrix. + """ + + # Computation of affinity matrix + connectivity = kneighbors_graph( + X, + n_neighbors=n_neighbors, + include_self=True, ) + self.affinity_matrix_ = 0.5 * (connectivity + connectivity.T) + + # Perform spectral clustering on affinity matrix + self.labels_ = spectral_clustering( + self.affinity_matrix_, + n_clusters=self.n_clusters, ) + return self + + +def is_overlapped(end1, start2): + """ + Returns True if segments are overlapping. + + Arguments + --------- + end1 : float + End time of the first segment. + start2 : float + Start time of the second segment. + + Returns + ------- + overlapped : bool + True if the segments overlap, else False. + + Example + ------- + >>> import diarization as diar + >>> diar.is_overlapped(5.5, 3.4) + True + >>> diar.is_overlapped(5.5, 6.4) + False + """ + + if start2 > end1: + return False + else: + return True + + +def merge_ssegs_same_speaker(lol): + """ + Merge adjacent sub-segs from the same speaker. + + Arguments + --------- + lol : list of list + Each list contains [rec_id, seg_start, seg_end, spkr_id]. + + Returns + ------- + new_lol : list of list + new_lol contains adjacent segments merged from the same speaker ID. + + Example + ------- + >>> import diarization as diar + >>> lol=[['r1', 5.5, 7.0, 's1'], + ... ['r1', 6.5, 9.0, 's1'], + ... ['r1', 8.0, 11.0, 's1'], + ... ['r1', 11.5, 13.0, 's2'], + ... ['r1', 14.0, 15.0, 's2'], + ... ['r1', 14.5, 15.0, 's1']] + >>> diar.merge_ssegs_same_speaker(lol) + [['r1', 5.5, 11.0, 's1'], ['r1', 11.5, 13.0, 's2'], ['r1', 14.0, 15.0, 's2'], ['r1', 14.5, 15.0, 's1']] + """ + + new_lol = [] + + # Start from the first sub-seg + sseg = lol[0] + flag = False + for i in range(1, len(lol)): + next_sseg = lol[i] + + # IF sub-segments overlap AND have the same speaker THEN merge + if is_overlapped(sseg[2], next_sseg[1]) and sseg[3] == next_sseg[3]: + sseg[2] = next_sseg[2] # just update the end time + # This is important. For the last sseg, if it is the same speaker then merge. + # Make sure we don't append the last segment once more. Hence, set FLAG=True + if i == len(lol) - 1: + flag = True + new_lol.append(sseg) + else: + new_lol.append(sseg) + sseg = next_sseg + + # Add last segment only when it was skipped earlier. + if flag is False: + new_lol.append(lol[-1]) + + return new_lol + + +def distribute_overlap(lol): + """ + Distributes the overlapped speech equally among the adjacent segments + with different speakers. + + Arguments + --------- + lol : list of list + It has each list structure as [rec_id, seg_start, seg_end, spkr_id]. + + Returns + ------- + new_lol : list of list + It contains the overlapped part equally divided among the adjacent + segments with different speaker IDs. + + Example + ------- + >>> import diarization as diar + >>> lol = [['r1', 5.5, 9.0, 's1'], + ... ['r1', 8.0, 11.0, 's2'], + ... ['r1', 11.5, 13.0, 's2'], + ...
['r1', 12.0, 15.0, 's1']] + >>> diar.distribute_overlap(lol) + [['r1', 5.5, 8.5, 's1'], ['r1', 8.5, 11.0, 's2'], ['r1', 11.5, 12.5, 's2'], ['r1', 12.5, 15.0, 's1']] + """ + + new_lol = [] + sseg = lol[0] + + # Add first sub-segment here to avoid error at: "if new_lol[-1] != sseg:" when new_lol is empty + # new_lol.append(sseg) + + for i in range(1, len(lol)): + next_sseg = lol[i] + # No need to check if they are different speakers. + # Because if segments are overlapped then they always have different speakers. + # This is because the same speaker's adjacent sub-segments have already been merged by "merge_ssegs_same_speaker()" + + if is_overlapped(sseg[2], next_sseg[1]): + + # Get overlap duration. + # Now this overlap will be divided equally between adjacent segments. + overlap = sseg[2] - next_sseg[1] + + # Update end time of old seg + sseg[2] = sseg[2] - (overlap / 2.0) + + # Update start time of next seg + next_sseg[1] = next_sseg[1] + (overlap / 2.0) + + if len(new_lol) == 0: + # For first sub-segment entry + new_lol.append(sseg) + else: + # To avoid duplicate entries + if new_lol[-1] != sseg: + new_lol.append(sseg) + + # Current sub-segment is next sub-segment + sseg = next_sseg + + else: + # For the first sseg + if len(new_lol) == 0: + new_lol.append(sseg) + else: + # To avoid duplicate entries + if new_lol[-1] != sseg: + new_lol.append(sseg) + + # Update the current sub-segment + sseg = next_sseg + + # Add the remaining last sub-segment + new_lol.append(next_sseg) + + return new_lol + + +def write_rttm(segs_list, out_rttm_file): + """ + Writes the segment list in RTTM format (a standard NIST format). + + Arguments + --------- + segs_list : list of list + Each list contains [rec_id, seg_start, seg_end, spkr_id]. + out_rttm_file : str + Path of the output RTTM file. + """ + + rttm = [] + rec_id = segs_list[0][0] + + for seg in segs_list: + new_row = [ + "SPEAKER", + rec_id, + "0", + str(round(seg[1], 4)), + str(round(seg[2] - seg[1], 4)), + "<NA>", + "<NA>", + seg[3], + "<NA>", + "<NA>", + ] + rttm.append(new_row) + + with open(out_rttm_file, "w") as f: + for row in rttm: + line_str = " ".join(row) + f.write("%s\n" % line_str) + + +def do_AHC(diary_obj, out_rttm_file, rec_id, k_oracle=4, p_val=0.3): + """ + Performs Agglomerative Hierarchical Clustering on embeddings. + + Arguments + --------- + diary_obj : EmbeddingMeta type + Contains embeddings in diary_obj.stats and segment IDs in diary_obj.segset. + out_rttm_file : str + Path of the output RTTM file. + rec_id : str + Recording ID for the recording under processing. + k_oracle : int + Number of speakers (None if it has to be estimated). + p_val : float + Threshold for pruning the affinity matrix. Used only when the number + of speakers is unknown. Note that this is just for experimentation; + prefer spectral clustering for better clustering results. + """ + + from sklearn.cluster import AgglomerativeClustering + + # p_val is the threshold_val (for AHC) + diary_obj.norm_stats() + + # processing + if k_oracle is not None: + num_of_spk = k_oracle + + clustering = AgglomerativeClustering( + n_clusters=num_of_spk, + affinity="cosine", + linkage="average", ).fit(diary_obj.stats) + labels = clustering.labels_ + + else: + # Estimate the number of speakers using the max eigen gap with a `cos` affinity matrix. + # This is just for experimentation.
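+ # Note: with n_clusters=None, sklearn cuts the dendrogram at + # distance_threshold (here p_val, on cosine distance), so the + # number of clusters is data-driven instead of fixed in advance.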
+ clustering = AgglomerativeClustering( + n_clusters=None, + affinity="cosine", + linkage="average", + distance_threshold=p_val, ).fit(diary_obj.stats) + labels = clustering.labels_ + + # Convert labels to speaker boundaries + subseg_ids = diary_obj.segset + lol = [] + + for i in range(labels.shape[0]): + spkr_id = rec_id + "_" + str(labels[i]) + + sub_seg = subseg_ids[i] + + splitted = sub_seg.rsplit("_", 2) + rec_id = str(splitted[0]) + sseg_start = float(splitted[1]) + sseg_end = float(splitted[2]) + + a = [rec_id, sseg_start, sseg_end, spkr_id] + lol.append(a) + + # Sorting based on start time of sub-segment + lol.sort(key=lambda x: float(x[1])) + + # Merge and split in 2 simple steps: (i) Merge ssegs of the same speaker then (ii) split different speakers + # Step 1: Merge adjacent sub-segments that belong to the same speaker (or cluster) + lol = merge_ssegs_same_speaker(lol) + + # Step 2: Distribute duration of adjacent overlapping sub-segments belonging to different speakers (or clusters) + # Taking the mid-point as the splitting time location. + lol = distribute_overlap(lol) + + # logger.info("Completed diarizing " + rec_id) + write_rttm(lol, out_rttm_file) + + +def do_spec_clustering(diary_obj, out_rttm_file, rec_id, k, pval, affinity_type, + n_neighbors): + """ + Performs spectral clustering on embeddings. This function calls specific + clustering algorithms as per affinity. + + Arguments + --------- + diary_obj : EmbeddingMeta type + Contains embeddings in diary_obj.stats and segment IDs in diary_obj.segset. + out_rttm_file : str + Path of the output RTTM file. + rec_id : str + Recording ID for the recording under processing. + k : int + Number of speakers (None if it has to be estimated). + pval : float + `pval` for pruning the affinity matrix. + affinity_type : str + Type of similarity to be used to get the affinity matrix (cos or nn). + n_neighbors : int + Number of neighbors used to build the affinity matrix (nn affinity only). + """ + + if affinity_type == "cos": + clust_obj = SpecClustUnorm(min_num_spkrs=2, max_num_spkrs=10) + k_oracle = k # use it only when oracle num of speakers + clust_obj.do_spec_clust(diary_obj.stats, k_oracle, pval) + labels = clust_obj.labels_ + else: + clust_obj = SpecCluster( + n_clusters=k, + assign_labels="kmeans", + random_state=1234, + affinity="nearest_neighbors", ) + clust_obj.perform_sc(diary_obj.stats, n_neighbors) + labels = clust_obj.labels_ + + # Convert labels to speaker boundaries + subseg_ids = diary_obj.segset + lol = [] + + for i in range(labels.shape[0]): + spkr_id = rec_id + "_" + str(labels[i]) + + sub_seg = subseg_ids[i] + + splitted = sub_seg.rsplit("_", 2) + rec_id = str(splitted[0]) + sseg_start = float(splitted[1]) + sseg_end = float(splitted[2]) + + a = [rec_id, sseg_start, sseg_end, spkr_id] + lol.append(a) + + # Sorting based on start time of sub-segment + lol.sort(key=lambda x: float(x[1])) + + # Merge and split in 2 simple steps: (i) Merge ssegs of the same speaker then (ii) split different speakers + # Step 1: Merge adjacent sub-segments that belong to the same speaker (or cluster) + lol = merge_ssegs_same_speaker(lol) + + # Step 2: Distribute duration of adjacent overlapping sub-segments belonging to different speakers (or clusters) + # Taking the mid-point as the splitting time location.
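+ # e.g. in the distribute_overlap() docstring example, s1 ends at 9.0 and + # s2 starts at 8.0, so both boundaries are moved to the midpoint 8.5.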
+ lol = distribute_overlap(lol) + + # logger.info("Completed diarizing " + rec_id) + write_rttm(lol, out_rttm_file) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + prog='python diarization.py --backend AHC', description='diarizing') + parser.add_argument( + '--sys_rttm_dir', + required=False, + help='Directory to store system RTTM files') + parser.add_argument( + '--ref_rttm_dir', + required=False, + help='Directory to store reference RTTM files') + parser.add_argument( + '--backend', default="AHC", help='type of backend, AHC or SC or kmeans') + parser.add_argument( + '--oracle_n_spkrs', + default=True, + type=strtobool, + help='Oracle num of speakers') + parser.add_argument( + '--mic_type', + default="Mix-Headset", + help='Type of microphone to be used') + parser.add_argument( + '--affinity', default="cos", help='affinity matrix, cos or nn') + parser.add_argument( + '--max_subseg_dur', + default=3.0, + type=float, + help='Duration in seconds of the subsegments to be prepared from larger segments' + ) + parser.add_argument( + '--overlap', + default=1.5, + type=float, + help='Overlap duration in seconds between adjacent subsegments') + + args = parser.parse_args() + + pval = 0.3 + rec_id = "utt0001" + n_neighbors = 10 + out_rttm_file = "./out.rttm" + + embeddings = np.empty(shape=[0, 32], dtype=np.float64) + segset = [] + + for i in range(10): + seg = [rec_id + "_" + str(i) + "_" + str(i + 1)] + segset = segset + seg + emb = np.random.rand(1, 32) + embeddings = np.concatenate((embeddings, emb), axis=0) + + segset = np.array(segset, dtype="|O") + stat_obj = EmbeddingMeta(segset, embeddings) + num_spkrs = None # estimated by the backend unless oracle_n_spkrs is set + if args.oracle_n_spkrs: + num_spkrs = 2 + + if args.backend == "SC": + print("begin SC ") + do_spec_clustering( + stat_obj, + out_rttm_file, + rec_id, + num_spkrs, + pval, + args.affinity, + n_neighbors, ) + if args.backend == "AHC": + print("begin AHC ") + do_AHC(stat_obj, out_rttm_file, rec_id, num_spkrs, pval) diff --git a/setup.py b/setup.py index f86758bab25..82ff6341265 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ HERE = Path(os.path.abspath(os.path.dirname(__file__))) -VERSION = '0.1.2' +VERSION = '0.2.0' base = [ "editdistance", diff --git a/speechx/.gitignore b/speechx/.gitignore new file mode 100644 index 00000000000..e0c61847075 --- /dev/null +++ b/speechx/.gitignore @@ -0,0 +1 @@ +tools/valgrind* diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index e003136a9d7..f1330d1da66 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -2,18 +2,32 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) project(paddlespeech VERSION 0.1) +set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0048.cmake") + set(CMAKE_VERBOSE_MAKEFILE on) + # set std-14 set(CMAKE_CXX_STANDARD 14) -# include file +# cmake dir +set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake) + +# Modules +list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}/external) +list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}) include(FetchContent) include(ExternalProject) + # fc_patch dir set(FETCHCONTENT_QUIET off) get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}") set(FETCHCONTENT_BASE_DIR ${fc_patch}) +# compiler option +# Keep the same with openfst, -fPIC or -fpic +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g") +SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb") +SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
############################################################################### # Option Configurations @@ -25,91 +39,92 @@ option(TEST_DEBUG "option for debug" OFF) ############################################################################### # Include third party ############################################################################### -# #example for include third party -# FetchContent_Declare() -# # FetchContent_MakeAvailable was not added until CMake 3.14 +# example for include third party +# FetchContent_MakeAvailable was not added until CMake 3.14 # FetchContent_MakeAvailable() # include_directories() +# gflags +include(gflags) + +# glog +include(glog) + +# gtest +include(gtest) + # ABSEIL-CPP -include(FetchContent) -FetchContent_Declare( - absl - GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git" - GIT_TAG "20210324.1" -) -FetchContent_MakeAvailable(absl) +include(absl) # libsndfile -include(FetchContent) -FetchContent_Declare( - libsndfile - GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git" - GIT_TAG "1.0.31" -) -FetchContent_MakeAvailable(libsndfile) +include(libsndfile) -# gflags -FetchContent_Declare( - gflags - URL https://github.com/gflags/gflags/archive/v2.2.1.zip - URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a -) -FetchContent_MakeAvailable(gflags) -include_directories(${gflags_BINARY_DIR}/include) +# boost +# include(boost) # not work +set(boost_SOURCE_DIR ${fc_patch}/boost-src) +set(BOOST_ROOT ${boost_SOURCE_DIR}) +# #find_package(boost REQUIRED PATHS ${BOOST_ROOT}) -# glog -FetchContent_Declare( - glog - URL https://github.com/google/glog/archive/v0.4.0.zip - URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc -) -FetchContent_MakeAvailable(glog) -include_directories(${glog_BINARY_DIR}) +# Eigen +include(eigen) +find_package(Eigen3 REQUIRED) -# gtest -FetchContent_Declare(googletest - URL https://github.com/google/googletest/archive/release-1.10.0.zip - URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91 -) -FetchContent_MakeAvailable(googletest) +# Kenlm +include(kenlm) +add_dependencies(kenlm eigen boost) + +#openblas +include(openblas) # openfst -set(openfst_SOURCE_DIR ${fc_patch}/openfst-src) -set(openfst_BINARY_DIR ${fc_patch}/openfst-build) -set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix) -ExternalProject_Add(openfst - URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip - URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6 - SOURCE_DIR ${openfst_SOURCE_DIR} - BINARY_DIR ${openfst_BINARY_DIR} - CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR} - "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" - "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" - "LIBS=-lgflags_nothreads -lglog -lpthread" - BUILD_COMMAND make -j 4 -) +include(openfst) add_dependencies(openfst gflags glog) -link_directories(${openfst_PREFIX_DIR}/lib) -include_directories(${openfst_PREFIX_DIR}/include) -add_subdirectory(speechx) -#openblas -#set(OpenBLAS_INSTALL_PREFIX ${fc_patch}/OpenBLAS) -#set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src) -#ExternalProject_Add( -# OpenBLAS -# GIT_REPOSITORY https://github.com/xianyi/OpenBLAS -# GIT_TAG v0.3.13 -# GIT_SHALLOW TRUE -# GIT_PROGRESS TRUE -# CONFIGURE_COMMAND "" -# BUILD_IN_SOURCE TRUE -# BUILD_COMMAND make USE_LOCKING=1 USE_THREAD=0 -# INSTALL_COMMAND make PREFIX=${OpenBLAS_INSTALL_PREFIX} 
install -# UPDATE_DISCONNECTED TRUE -#) +# paddle lib +set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib) +set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix) +ExternalProject_Add(paddle + URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz + URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873 + PREFIX ${paddle_PREFIX_DIR} + SOURCE_DIR ${paddle_SOURCE_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" +) + +set(PADDLE_LIB ${fc_patch}/paddle-lib) +include_directories("${PADDLE_LIB}/paddle/include") +set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include") + +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib") +link_directories("${PADDLE_LIB}/paddle/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib") + +##paddle with mkl +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml") +include_directories("${MATH_LIB_PATH}/include") +set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) +set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn") +include_directories("${MKLDNN_PATH}/include") +set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) +set(EXTERNAL_LIB "-lrt -ldl -lpthread") + +set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) +set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags protobuf xxhash cryptopp + ${EXTERNAL_LIB}) + + ############################################################################### # Add local library @@ -121,4 +136,9 @@ add_subdirectory(speechx) # if dir do not have CmakeLists.txt #add_library(lib_name STATIC file.cc) #target_link_libraries(lib_name item0 item1) -#add_dependencies(lib_name depend-target) \ No newline at end of file +#add_dependencies(lib_name depend-target) + +set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) + +add_subdirectory(speechx) +add_subdirectory(examples) \ No newline at end of file diff --git a/speechx/README.md b/speechx/README.md new file mode 100644 index 00000000000..7d73b61c6fa --- /dev/null +++ b/speechx/README.md @@ -0,0 +1,61 @@ +# SpeechX -- All in One Speech Task Inference + +## Environment + +We develop under: +* docker - registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7 +* os - Ubuntu 16.04.7 LTS +* gcc/g++ - 8.2.0 +* cmake - 3.16.0 + +> We make sure everything works fine under docker, and recommend using it for development and deployment. + +* [How to Install Docker](https://docs.docker.com/engine/install/) +* [A Docker Tutorial for Beginners](https://docker-curriculum.com/) +* [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/overview.html) + +## Build + +1. First, launch the docker container. + +``` +nvidia-docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7 /bin/bash +``` + +* More `Paddle` docker images can be found [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
+ + * If you only want to work on CPU, please download the corresponding [image](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html), and use `docker` instead of `nvidia-docker`. + + +2. Build `speechx` and `examples`. + +``` +pushd /path/to/speechx +./build.sh +``` + +3. Go to `examples` to have fun. + +For more details, please see the `README.md` under `examples`. + + +## Valgrind (Optional) + +> If using docker, please check that `--privileged` is set when you `docker run`. + +* Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up` +``` +apt-get install libc6-dbg +``` + +* Install + +``` +pushd tools +./setup_valgrind.sh +popd +``` + +## TODO + +* DecibelNormalizer: there is a slight difference between offline and online db norm. Online db norm reads the feature chunk by chunk, so the feature size differs from that of offline db norm. In normalizer.cc:73, samples.size() is different, which causes the difference in results. diff --git a/speechx/build.sh b/speechx/build.sh new file mode 100755 index 00000000000..3e9600d538c --- /dev/null +++ b/speechx/build.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +# the build script has been verified in the paddlepaddle docker image. +# please follow the instructions below to install the PaddlePaddle image. +# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html + +boost_SOURCE_DIR=$PWD/fc_patch/boost-src +if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz + tar xzfv boost_1_75_0.tar.gz + mkdir -p $PWD/fc_patch + mv boost_1_75_0 ${boost_SOURCE_DIR} + cd ${boost_SOURCE_DIR} + bash ./bootstrap.sh + ./b2 + cd - + echo -e "\n" +fi + +#rm -rf build +mkdir -p build +cd build + +cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR} +#cmake .. + +make -j1 + +cd - diff --git a/speechx/cmake/EnableCMP0048.cmake b/speechx/cmake/EnableCMP0048.cmake new file mode 100644 index 00000000000..1b59188fd7d --- /dev/null +++ b/speechx/cmake/EnableCMP0048.cmake @@ -0,0 +1 @@ +cmake_policy(SET CMP0048 NEW) \ No newline at end of file diff --git a/speechx/cmake/external/absl.cmake b/speechx/cmake/external/absl.cmake new file mode 100644 index 00000000000..2c5e5af5ca5 --- /dev/null +++ b/speechx/cmake/external/absl.cmake @@ -0,0 +1,16 @@ +include(FetchContent) + + +set(BUILD_SHARED_LIBS OFF) # up to you +set(BUILD_TESTING OFF) # disable abseil tests, or gtest will fail. +set(ABSL_ENABLE_INSTALL ON) # now you can enable install rules even in subproject...
+ +FetchContent_Declare( + absl + GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git" + GIT_TAG "20210324.1" +) +FetchContent_MakeAvailable(absl) + +set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR}) +include_directories(${absl_SOURCE_DIR}) \ No newline at end of file diff --git a/speechx/cmake/external/boost.cmake b/speechx/cmake/external/boost.cmake new file mode 100644 index 00000000000..6bc97aad4da --- /dev/null +++ b/speechx/cmake/external/boost.cmake @@ -0,0 +1,27 @@ +include(FetchContent) +set(Boost_DEBUG ON) + +set(Boost_PREFIX_DIR ${fc_patch}/boost) +set(Boost_SOURCE_DIR ${fc_patch}/boost-src) + +FetchContent_Declare( + Boost + URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz + URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a + PREFIX ${Boost_PREFIX_DIR} + SOURCE_DIR ${Boost_SOURCE_DIR} +) + +execute_process(COMMAND bootstrap.sh WORKING_DIRECTORY ${Boost_SOURCE_DIR}) +execute_process(COMMAND b2 WORKING_DIRECTORY ${Boost_SOURCE_DIR}) + +FetchContent_MakeAvailable(Boost) + +message(STATUS "boost src dir: ${Boost_SOURCE_DIR}") +message(STATUS "boost inc dir: ${Boost_INCLUDE_DIR}") +message(STATUS "boost bin dir: ${Boost_BINARY_DIR}") + +set(BOOST_ROOT ${Boost_SOURCE_DIR}) +message(STATUS "boost root dir: ${BOOST_ROOT}") + +include_directories(${Boost_SOURCE_DIR}) \ No newline at end of file diff --git a/speechx/cmake/external/eigen.cmake b/speechx/cmake/external/eigen.cmake new file mode 100644 index 00000000000..12bd3cdf517 --- /dev/null +++ b/speechx/cmake/external/eigen.cmake @@ -0,0 +1,27 @@ +include(FetchContent) + +# update eigen to the commit id f612df27 on 03/16/2021 +set(EIGEN_PREFIX_DIR ${fc_patch}/eigen3) + +FetchContent_Declare( + Eigen3 + GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git + GIT_TAG master + PREFIX ${EIGEN_PREFIX_DIR} + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE) + +set(EIGEN_BUILD_DOC OFF) +# note: To disable eigen tests, +# you should put this code in an add_subdirectory to avoid changing +# BUILD_TESTING for your own project too, since variables are directory +# scoped +set(BUILD_TESTING OFF) +set(EIGEN_BUILD_PKGCONFIG OFF) +FetchContent_MakeAvailable(Eigen3) + +message(STATUS "eigen src dir: ${Eigen3_SOURCE_DIR}") +message(STATUS "eigen bin dir: ${Eigen3_BINARY_DIR}") +#include_directories(${Eigen3_SOURCE_DIR}) +#link_directories(${Eigen3_BINARY_DIR}) \ No newline at end of file diff --git a/speechx/cmake/external/gflags.cmake b/speechx/cmake/external/gflags.cmake new file mode 100644 index 00000000000..66ae47f7098 --- /dev/null +++ b/speechx/cmake/external/gflags.cmake @@ -0,0 +1,12 @@ +include(FetchContent) + +FetchContent_Declare( + gflags + URL https://github.com/gflags/gflags/archive/v2.2.1.zip + URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a +) + +FetchContent_MakeAvailable(gflags) + +# needed by openfst +include_directories(${gflags_BINARY_DIR}/include) \ No newline at end of file diff --git a/speechx/cmake/external/glog.cmake b/speechx/cmake/external/glog.cmake new file mode 100644 index 00000000000..dcfd86c3ed5 --- /dev/null +++ b/speechx/cmake/external/glog.cmake @@ -0,0 +1,8 @@ +include(FetchContent) +FetchContent_Declare( + glog + URL https://github.com/google/glog/archive/v0.4.0.zip + URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc +) +FetchContent_MakeAvailable(glog) +include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src) diff --git a/speechx/cmake/external/gtest.cmake
b/speechx/cmake/external/gtest.cmake new file mode 100644 index 00000000000..7fe397fcb08 --- /dev/null +++ b/speechx/cmake/external/gtest.cmake @@ -0,0 +1,9 @@ +include(FetchContent) +FetchContent_Declare( + gtest + URL https://github.com/google/googletest/archive/release-1.10.0.zip + URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91 +) +FetchContent_MakeAvailable(gtest) + +include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) \ No newline at end of file diff --git a/speechx/cmake/external/kenlm.cmake b/speechx/cmake/external/kenlm.cmake new file mode 100644 index 00000000000..17c76c3f633 --- /dev/null +++ b/speechx/cmake/external/kenlm.cmake @@ -0,0 +1,10 @@ +include(FetchContent) +FetchContent_Declare( + kenlm + GIT_REPOSITORY "https://github.com/kpu/kenlm.git" + GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a" +) +# https://github.com/kpu/kenlm/blob/master/cmake/modules/FindEigen3.cmake +set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR}) +FetchContent_MakeAvailable(kenlm) +include_directories(${kenlm_SOURCE_DIR}) \ No newline at end of file diff --git a/speechx/cmake/external/libsndfile.cmake b/speechx/cmake/external/libsndfile.cmake new file mode 100644 index 00000000000..52d64bacd31 --- /dev/null +++ b/speechx/cmake/external/libsndfile.cmake @@ -0,0 +1,56 @@ +include(FetchContent) + +# https://github.com/pongasoft/vst-sam-spl-64/blob/master/libsndfile.cmake +# https://github.com/popojan/goban/blob/master/CMakeLists.txt#L38 +# https://github.com/ddiakopoulos/libnyquist/blob/master/CMakeLists.txt + +if(LIBSNDFILE_ROOT_DIR) + # instructs FetchContent to not download or update but use the location instead + set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE ${LIBSNDFILE_ROOT_DIR}) +else() + set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE "") +endif() + +set(LIBSNDFILE_GIT_REPO "https://github.com/libsndfile/libsndfile.git" CACHE STRING "libsndfile git repository url" FORCE) +set(LIBSNDFILE_GIT_TAG 1.0.31 CACHE STRING "libsndfile git tag" FORCE) + +FetchContent_Declare(libsndfile + GIT_REPOSITORY ${LIBSNDFILE_GIT_REPO} + GIT_TAG ${LIBSNDFILE_GIT_TAG} + GIT_CONFIG advice.detachedHead=false +# GIT_SHALLOW true + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) + +FetchContent_GetProperties(libsndfile) +if(NOT libsndfile_POPULATED) + if(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE) + message(STATUS "Using libsndfile from local ${FETCHCONTENT_SOURCE_DIR_LIBSNDFILE}") + else() + message(STATUS "Fetching libsndfile ${LIBSNDFILE_GIT_REPO}/tree/${LIBSNDFILE_GIT_TAG}") + endif() + FetchContent_Populate(libsndfile) +endif() + +set(LIBSNDFILE_ROOT_DIR ${libsndfile_SOURCE_DIR}) +set(LIBSNDFILE_INCLUDE_DIR "${libsndfile_BINARY_DIR}/src") + +function(libsndfile_build) + option(BUILD_PROGRAMS "Build programs" OFF) + option(BUILD_EXAMPLES "Build examples" OFF) + option(BUILD_TESTING "Build tests" OFF) + option(ENABLE_CPACK "Enable CPack support" OFF) + option(ENABLE_PACKAGE_CONFIG "Generate and install package config file" OFF) + option(BUILD_REGTEST "Build regtest" OFF) + # finally we include libsndfile itself + add_subdirectory(${libsndfile_SOURCE_DIR} ${libsndfile_BINARY_DIR} EXCLUDE_FROM_ALL) + # copying .hh for c++ support + #file(COPY "${libsndfile_SOURCE_DIR}/src/sndfile.hh" DESTINATION ${LIBSNDFILE_INCLUDE_DIR}) +endfunction() + +libsndfile_build() + +include_directories(${LIBSNDFILE_INCLUDE_DIR}) \ No newline at end of file diff --git a/speechx/cmake/external/openblas.cmake b/speechx/cmake/external/openblas.cmake new file mode 100644 index
00000000000..3c202f7f689 --- /dev/null +++ b/speechx/cmake/external/openblas.cmake @@ -0,0 +1,37 @@ +include(FetchContent) + +set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src) +set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix) + +# ###################################################################################################################### +# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575 +# ###################################################################################################################### +enable_language(Fortran) +#TODO: switch to CPM +include(GNUInstallDirs) +ExternalProject_Add( + OPENBLAS + GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git + GIT_TAG v0.3.10 + GIT_SHALLOW YES + PREFIX ${OpenBLAS_PREFIX} + SOURCE_DIR ${OpenBLAS_SOURCE_DIR} + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR> + CMAKE_GENERATOR "Unix Makefiles") + + +# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition +ExternalProject_Get_Property(OPENBLAS INSTALL_DIR) +set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR}) +add_library(openblas STATIC IMPORTED) +add_dependencies(openblas OPENBLAS) +set_target_properties(openblas PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES Fortran) +# ${CMAKE_INSTALL_LIBDIR} lib +set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libopenblas.a) + + +# https://cmake.org/cmake/help/latest/command/install.html?highlight=cmake_install_libdir#installing-targets +# ${CMAKE_INSTALL_LIBDIR} lib +# ${CMAKE_INSTALL_INCLUDEDIR} include +link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}) +include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}) \ No newline at end of file diff --git a/speechx/cmake/external/openfst.cmake b/speechx/cmake/external/openfst.cmake new file mode 100644 index 00000000000..07abb18e81d --- /dev/null +++ b/speechx/cmake/external/openfst.cmake @@ -0,0 +1,19 @@ +include(FetchContent) +set(openfst_SOURCE_DIR ${fc_patch}/openfst-src) +set(openfst_BINARY_DIR ${fc_patch}/openfst-build) +set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix) + +ExternalProject_Add(openfst + URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip + URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6 +# #PREFIX ${openfst_PREFIX_DIR} +# SOURCE_DIR ${openfst_SOURCE_DIR} +# BINARY_DIR ${openfst_BINARY_DIR} + CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR} + "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" + "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" + "LIBS=-lgflags_nothreads -lglog -lpthread" + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} + BUILD_COMMAND make -j 4 +) +link_directories(${openfst_PREFIX_DIR}/lib) +include_directories(${openfst_PREFIX_DIR}/include) diff --git a/speechx/examples/.gitignore b/speechx/examples/.gitignore new file mode 100644 index 00000000000..b7075fa56c3 --- /dev/null +++ b/speechx/examples/.gitignore @@ -0,0 +1,2 @@ +*.ark +paddle_asr_model/ diff --git a/speechx/examples/.gitkeep b/speechx/examples/.gitkeep deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/speechx/examples/CMakeLists.txt b/speechx/examples/CMakeLists.txt new file mode 100644 index 00000000000..ef0a72b8838 --- /dev/null +++ b/speechx/examples/CMakeLists.txt @@ -0,0 +1,5 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_subdirectory(feat)
+add_subdirectory(nnet) +add_subdirectory(decoder) diff --git a/speechx/examples/README.md b/speechx/examples/README.md new file mode 100644 index 00000000000..941c4272d9a --- /dev/null +++ b/speechx/examples/README.md @@ -0,0 +1,16 @@ +# Examples + +* decoder - online decoder working in offline mode +* feat - mfcc, linear +* nnet - ds2 nn + +## How to run + +`run.sh` is the entry point. + +For example, to play with `decoder`: + +``` +pushd decoder +bash run.sh +``` diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt new file mode 100644 index 00000000000..4bd5c6cf066 --- /dev/null +++ b/speechx/examples/decoder/CMakeLists.txt @@ -0,0 +1,5 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc) +target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc new file mode 100644 index 00000000000..44127c73b4e --- /dev/null +++ b/speechx/examples/decoder/offline_decoder_main.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
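+ +// Illustrative usage (these flags are defined below; see run.sh in this +// directory for the full feature-extraction + decoding pipeline): +// offline_decoder_main --feature_respecifier=ark:feats.ark \ +// --model_path=avg_1.jit.pdmodel --param_path=avg_1.jit.pdiparams \ +// --dict_file=vocab.txt --lm_path=lm.klm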
+ +// todo refactor, replace with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "decoder/ctc_beam_search_decoder.h" +#include "frontend/raw_audio.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +DEFINE_string(feature_respecifier, "", "test feature rspecifier"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); +DEFINE_string(lm_path, "lm.klm", "language model"); + + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_respecifier); + std::string model_graph = FLAGS_model_path; + std::string model_params = FLAGS_param_path; + std::string dict_file = FLAGS_dict_file; + std::string lm_path = FLAGS_lm_path; + + int32 num_done = 0, num_err = 0; + + ppspeech::CTCBeamSearchOptions opts; + opts.dict_file = dict_file; + opts.lm_path = lm_path; + ppspeech::CTCBeamSearch decoder(opts); + + ppspeech::ModelOptions model_opts; + model_opts.model_path = model_graph; + model_opts.params_path = model_params; + std::shared_ptr<ppspeech::PaddleNnet> nnet( + new ppspeech::PaddleNnet(model_opts)); + std::shared_ptr<ppspeech::RawDataCache> raw_data( + new ppspeech::RawDataCache()); + std::shared_ptr<ppspeech::Decodable> decodable( + new ppspeech::Decodable(nnet, raw_data)); + + int32 chunk_size = 35; + decoder.InitDecoder(); + + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string utt = feature_reader.Key(); + const kaldi::Matrix<BaseFloat> feature = feature_reader.Value(); + raw_data->SetDim(feature.NumCols()); + int32 row_idx = 0; + int32 num_chunks = feature.NumRows() / chunk_size; + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + kaldi::Vector<BaseFloat> feature_chunk(chunk_size * + feature.NumCols()); + for (int row_id = 0; row_id < chunk_size; ++row_id) { + kaldi::SubVector<BaseFloat> tmp(feature, row_idx); + kaldi::SubVector<BaseFloat> f_chunk_tmp( + feature_chunk.Data() + row_id * feature.NumCols(), + feature.NumCols()); + f_chunk_tmp.CopyFromVec(tmp); + row_idx++; + } + raw_data->Accept(feature_chunk); + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + decoder.AdvanceDecode(decodable); + } + std::string result; + result = decoder.GetFinalBestPath(); + KALDI_LOG << " the result of " << utt << " is " << result; + decodable->Reset(); + decoder.Reset(); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/examples/decoder/path.sh b/speechx/examples/decoder/path.sh new file mode 100644 index 00000000000..7b4b7545b38 --- /dev/null +++ b/speechx/examples/decoder/path.sh @@ -0,0 +1,14 @@ +# This contains the locations of the binaries built and required for running the examples. + +SPEECHX_ROOT=$PWD/../.. +SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples + +SPEECHX_TOOLS=$SPEECHX_ROOT/tools +TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin + +[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found.
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/decoder/run.sh b/speechx/examples/decoder/run.sh
new file mode 100755
index 00000000000..fc5e9182463
--- /dev/null
+++ b/speechx/examples/decoder/run.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+    pushd ${SPEECHX_ROOT}
+    bash build.sh
+    popd
+fi
+
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+    tar xzvf paddle_asr_model.tar.gz
+    mv ./paddle_asr_model ../
+    # produce wav scp
+    echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+# 3. run feat
+linear_spectrogram_main \
+    --wav_rspecifier=scp:$model_dir/wav.scp \
+    --feature_wspecifier=ark,t:$feat_wspecifier \
+    --cmvn_write_path=$cmvn
+
+# 4. run decoder
+offline_decoder_main \
+    --feature_respecifier=ark:$feat_wspecifier \
+    --model_path=$model_dir/avg_1.jit.pdmodel \
+    --param_path=$model_dir/avg_1.jit.pdparams \
+    --dict_file=$model_dir/vocab.txt \
+    --lm_path=$model_dir/avg_1.jit.klm
\ No newline at end of file
diff --git a/speechx/examples/decoder/valgrind.sh b/speechx/examples/decoder/valgrind.sh
new file mode 100755
index 00000000000..14efe0ba42b
--- /dev/null
+++ b/speechx/examples/decoder/valgrind.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# this script is for memory check, so please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+    echo "please install valgrind in the speechx tools dir."
+    exit 1
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+    offline_decoder_main \
+    --feature_respecifier=ark:$feat_wspecifier \
+    --model_path=$model_dir/avg_1.jit.pdmodel \
+    --param_path=$model_dir/avg_1.jit.pdparams \
+    --dict_file=$model_dir/vocab.txt \
+    --lm_path=$model_dir/avg_1.jit.klm
+
diff --git a/speechx/examples/feat/CMakeLists.txt b/speechx/examples/feat/CMakeLists.txt
new file mode 100644
index 00000000000..b8f516afb5a
--- /dev/null
+++ b/speechx/examples/feat/CMakeLists.txt
@@ -0,0 +1,10 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+
+add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
+target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(mfcc-test kaldi-mfcc)
+
+add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
+target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
\ No newline at end of file
diff --git a/speechx/examples/feat/feature-mfcc-test.cc b/speechx/examples/feat/feature-mfcc-test.cc
new file mode 100644
index 00000000000..ae32aba9e6a
--- /dev/null
+++ b/speechx/examples/feat/feature-mfcc-test.cc
@@ -0,0 +1,720 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// feat/feature-mfcc-test.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include + +#include "base/kaldi-math.h" +#include "feat/feature-mfcc.h" +#include "feat/wave-reader.h" +#include "matrix/kaldi-matrix-inl.h" + +using namespace kaldi; + + +static void UnitTestReadWave() { + std::cout << "=== UnitTestReadWave() ===\n"; + + Vector v, v2; + + std::cout << "<<<=== Reading waveform\n"; + + { + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + const Matrix data(wave.Data()); + KALDI_ASSERT(data.NumRows() == 1); + v.Resize(data.NumCols()); + v.CopyFromVec(data.Row(0)); + } + + std::cout + << "<<<=== Reading Vector waveform, prepared by matlab\n"; + std::ifstream input("test_data/test_matlab.ascii"); + KALDI_ASSERT(input.good()); + v2.Read(input, false); + input.close(); + + std::cout + << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n"; + KALDI_ASSERT(v.Dim() == v2.Dim()); + for (int32 i = 0; i < v.Dim(); i++) { + KALDI_ASSERT(v(i) == v2(i)); + } + std::cout << "<<<=== Comparing done\n"; + + // std::cout << "== The Waveform Samples == \n"; + // std::cout << v; + + std::cout << "Test passed :)\n\n"; +} + + +/** + */ +static void UnitTestSimple() { + std::cout << "=== UnitTestSimple() ===\n"; + + Vector v(100000); + Matrix m; + + // init with noise + for (int32 i = 0; i < v.Dim(); i++) { + v(i) = (abs(i * 433024253) % 65535) - (65535 / 2); + } + + std::cout << "<<<=== Just make sure it runs... Nothing is compared\n"; + // the parametrization object + MfccOptions op; + // trying to have same opts as baseline. + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "rectangular"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + + Mfcc mfcc(op); + // use default parameters + + // compute mfccs. 
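+    // Mfcc::Compute(waveform, vtln_warp, output): the second argument is the
+    // VTLN warp factor, so 1.0 below means "no warping".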
+ mfcc.Compute(v, 1.0, &m); + + // possibly dump + // std::cout << "== Output features == \n" << m; + std::cout << "Test passed :)\n\n"; +} + + +static void UnitTestHTKCompare1() { + std::cout << "=== UnitTestHTKCompare1() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.1", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + op.use_energy = false; // C0 not energy. + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (i_old != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.1", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.1"); +} + + +static void UnitTestHTKCompare2() { + std::cout << "=== UnitTestHTKCompare2() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.2", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. 
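+    // htk_compat plus mel_opts.htk_mode make Kaldi reproduce HTK's MFCC
+    // layout, so the two feature matrices can be compared cell by cell below.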
+ + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (i_old != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.2", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.2"); +} + + +static void UnitTestHTKCompare3() { + std::cout << "=== UnitTestHTKCompare3() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.3", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. + op.mel_opts.low_freq = 20.0; + // op.mel_opts.debug_mel = true; + op.mel_opts.htk_mode = true; + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.3", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.3"); +} + + +static void UnitTestHTKCompare4() { + std::cout << "=== UnitTestHTKCompare4() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.4", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.htk_compat = true; + op.use_energy = true; // Use energy. + op.mel_opts.htk_mode = true; + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.4", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.4"); +} + + +static void UnitTestHTKCompare5() { + std::cout << "=== UnitTestHTKCompare5() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.5", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. + op.mel_opts.low_freq = 0.0; + op.mel_opts.vtln_low = 100.0; + op.mel_opts.vtln_high = 7500.0; + op.mel_opts.htk_mode = true; + + BaseFloat vtln_warp = + 1.1; // our approach identical to htk for warp factor >1, + // differs slightly for higher mel bins if warp_factor <0.9 + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.5", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.5"); +} + +static void UnitTestHTKCompare6() { + std::cout << "=== UnitTestHTKCompare6() ===\n"; + + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.6", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.97; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.num_bins = 24; + op.mel_opts.low_freq = 125.0; + op.mel_opts.high_freq = 7800.0; + op.htk_compat = true; + op.use_energy = false; // C0 not energy. + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.6", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.6"); +} + +void UnitTestVtln() { + // Test the function VtlnWarpFreq. 
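+    // VtlnWarpFreq is piecewise linear: it maps low_freq and high_freq to
+    // themselves, scales the central band by 1/warp_factor, stays
+    // monotonically increasing, and is the identity for warp_factor == 1.0.
+    // The assertions below check exactly these properties.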
+ BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20, + vtln_high_cutoff = 7400; + + for (size_t i = 0; i < 100; i++) { + BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2; + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + freq), + freq / warp_factor); + + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + low_freq), + low_freq); + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + high_freq), + high_freq); + BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(), + freq3 = freq2 + + (high_freq - freq2) * RandUniform(); // freq3>=freq2 + BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + freq2); + BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + freq3); + KALDI_ASSERT(w3 >= w2); // increasing function. + BaseFloat w3dash = MelBanks::VtlnWarpFreq( + vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3); + AssertEqual(w3dash, freq3); + } +} + +static void UnitTestFeat() { + UnitTestVtln(); + UnitTestReadWave(); + UnitTestSimple(); + UnitTestHTKCompare1(); + UnitTestHTKCompare2(); + // commenting out this one as it doesn't compare right now I normalized + // the way the FFT bins are treated (removed offset of 0.5)... this seems + // to relate to the way frequency zero behaves. + UnitTestHTKCompare3(); + UnitTestHTKCompare4(); + UnitTestHTKCompare5(); + UnitTestHTKCompare6(); + std::cout << "Tests succeeded.\n"; +} + + +int main() { + try { + for (int i = 0; i < 5; i++) UnitTestFeat(); + std::cout << "Tests succeeded.\n"; + return 0; + } catch (const std::exception &e) { + std::cerr << e.what(); + return 1; + } +} diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc new file mode 100644 index 00000000000..9ed4d6f9344 --- /dev/null +++ b/speechx/examples/feat/linear_spectrogram_main.cc @@ -0,0 +1,248 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
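+
+// This example drives the streaming feature pipeline end to end:
+// WAV -> RawAudioCache -> DecibelNormalizer -> LinearSpectrogram -> CMVN
+// -> FeatureCache, feeding 0.36 s chunks and writing a Kaldi feature archive.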
+ +// todo refactor, repalce with gtest + +#include "frontend/linear_spectrogram.h" +#include "base/flags.h" +#include "base/log.h" +#include "frontend/feature_cache.h" +#include "frontend/feature_extractor_interface.h" +#include "frontend/normalizer.h" +#include "frontend/raw_audio.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(wav_rspecifier, "", "test wav scp path"); +DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); +DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); + + +std::vector mean_{ + -13730251.531853663, -12982852.199316509, -13673844.299583456, + -13089406.559646806, -12673095.524938712, -12823859.223276224, + -13590267.158903603, -14257618.467152044, -14374605.116185192, + -14490009.21822485, -14849827.158924166, -15354435.470563512, + -15834149.206532761, -16172971.985514281, -16348740.496746974, + -16423536.699409386, -16556246.263649225, -16744088.772748645, + -16916184.08510357, -17054034.840031497, -17165612.509455364, + -17255955.470915023, -17322572.527648456, -17408943.862033736, + -17521554.799865916, -17620623.254924215, -17699792.395918526, + -17723364.411134344, -17741483.4433254, -17747426.888704527, + -17733315.928209435, -17748780.160905756, -17808336.883775543, + -17895918.671983004, -18009812.59173023, -18098188.66548325, + -18195798.958462656, -18293617.62980999, -18397432.92077201, + -18505834.787318766, -18585451.8100908, -18652438.235649142, + -18700960.306275308, -18734944.58792185, -18737426.313365128, + -18735347.165987637, -18738813.444170244, -18737086.848890636, + -18731576.2474336, -18717405.44095871, -18703089.25545657, + -18691014.546456724, -18692460.568905357, -18702119.628629155, + -18727710.621126678, -18761582.72034647, -18806745.835547544, + -18850674.8692112, -18884431.510951452, -18919999.992506847, + -18939303.799078144, -18952946.273760635, -18980289.22996379, + -19011610.17803294, -19040948.61805145, -19061021.429847397, + -19112055.53768819, -19149667.414264943, -19201127.05091321, + -19270250.82564605, -19334606.883057203, -19390513.336589377, + -19444176.259208687, -19502755.000038862, -19544333.014549147, + -19612668.183176614, -19681902.19006569, -19771969.951249883, + -19873329.723376893, -19996752.59235844, -20110031.131400537, + -20231658.612529557, -20319378.894054495, -20378534.45718066, + -20413332.089584175, -20438147.844177883, -20443710.248040095, + -20465457.02238927, -20488610.969337028, -20516295.16424432, + -20541423.795738827, -20553192.874953747, -20573605.50701977, + -20577871.61936797, -20571807.008916274, -20556242.38912231, + -20542199.30819195, -20521239.063551214, -20519150.80004532, + -20527204.80248933, -20536933.769257784, -20543470.522332076, + -20549700.089992985, -20551525.24958494, -20554873.406493705, + -20564277.65794227, -20572211.740052115, -20574305.69550465, + -20575494.450104576, -20567092.577932164, -20549302.929608088, + -20545445.11878376, -20546625.326603737, -20549190.03499401, + -20554824.947828256, -20568341.378989458, -20577582.331383612, + -20577980.519402675, -20566603.03458152, -20560131.592262644, + -20552166.469060015, -20549063.06763577, -20544490.562339947, + -20539817.82346569, -20528747.715731595, -20518026.24576161, + -20510977.844974525, -20506874.36087992, -20506731.11977665, + -20510482.133420516, -20507760.92101862, -20494644.834457114, + -20480107.89304893, -20461312.091867123, -20442941.75080173, + -20426123.02834838, -20424607.675283, -20426810.369107097, + 
-20434024.50097819, -20437404.75544205, -20447688.63916367, + -20460893.335563846, -20482922.735127095, -20503610.119434915, + -20527062.76448319, -20557830.035128627, -20593274.72068722, + -20632528.452965066, -20673637.471334763, -20733106.97143075, + -20842921.0447562, -21054357.83621519, -21416569.534189366, + -21978460.272811692, -22753170.052172784, -23671344.10563395, + -24613499.293358143, -25406477.12230188, -25884377.82156489, + -26049040.62791664, -26996879.104431007}; +std::vector variance_{ + 213747175.10846674, 188395815.34302503, 212706429.10966414, + 199109025.81461075, 189235901.23864496, 194901336.53253657, + 217481594.29306737, 238689869.12327808, 243977501.24115244, + 248479623.6431067, 259766741.47116545, 275516766.7790273, + 291271202.3691234, 302693239.8220509, 308627358.3997694, + 311143911.38788426, 315446105.07731867, 321705430.9341829, + 327458907.4659941, 332245072.43223983, 336251717.5935284, + 339694069.7639722, 342188204.4322228, 345587110.31313115, + 349903086.2875232, 353660214.20643026, 356700344.5270885, + 357665362.3529641, 358493352.05658793, 358857951.620328, + 358375239.52774596, 358899733.6342954, 361051818.3511561, + 364361716.05025816, 368750322.3771452, 372047800.6462831, + 375655861.1349018, 379358519.1980013, 383327605.3935181, + 387458599.282341, 390434692.3406868, 392994486.35057056, + 394874418.04603153, 396230525.79763395, 396365592.0414835, + 396334819.8242737, 396488353.19250053, 396438877.00744957, + 396197980.4459586, 395590921.6672991, 395001107.62072515, + 394528291.7318225, 394593110.424006, 395018405.59353715, + 396110577.5415993, 397506704.0371068, 399400197.4657644, + 401243568.2468382, 402687134.7805103, 404136047.2872507, + 404883170.001883, 405522253.219517, 406660365.3626476, + 407919346.0991902, 409045348.5384909, 409759588.7889818, + 411974821.8564483, 413489718.78201455, 415535392.56684107, + 418466481.97674364, 421104678.35678065, 423405392.5200779, + 425550570.40798235, 427929423.9579701, 429585274.253478, + 432368493.55181056, 435193587.13513297, 438886855.20476013, + 443058876.8633751, 448181232.5093362, 452883835.6332396, + 458056721.77926534, 461816531.22735566, 464363620.1970998, + 465886343.5057493, 466928872.0651, 467180536.42647296, + 468111848.70714295, 469138695.3071312, 470378429.6930793, + 471517958.7132626, 472109050.4262365, 473087417.0177867, + 473381322.04648733, 473220195.85483915, 472666071.8998819, + 472124669.87879956, 471298571.411737, 471251033.2902761, + 471672676.43128747, 472177147.2193172, 472572361.7711908, + 472968783.7751127, 473156295.4164052, 473398034.82676554, + 473897703.5203811, 474328271.33112127, 474452670.98002136, + 474549003.99284613, 474252887.13567275, 473557462.909069, + 473483385.85193115, 473609738.04855174, 473746944.82085115, + 474016729.91696435, 474617321.94138587, 475045097.237122, + 475125402.586558, 474664112.9824912, 474426247.5800283, + 474104075.42796475, 473978219.7273978, 473773171.7798875, + 473578534.69508696, 473102924.16904145, 472651240.5232615, + 472374383.1810912, 472209479.6956096, 472202298.8921673, + 472370090.76781124, 472220933.99374026, 471625467.37106377, + 470994646.51883453, 470182428.9637543, 469348211.5939578, + 468570387.4467277, 468540442.7225135, 468672018.90414184, + 468994346.9533251, 469138757.58201426, 469553915.95710236, + 470134523.38582784, 471082421.62055486, 471962316.51804745, + 472939745.1708408, 474250621.5944825, 475773933.43199486, + 477465399.71087736, 479218782.61382693, 481752299.7930922, + 486608947.8984568, 496119403.2067917, 
512730085.5704984,
+    539048915.2641417, 576285298.3548826, 621610270.2240586,
+    669308196.4436442, 710656993.5957186, 736344437.3725077,
+    745481288.0241544, 801121432.9925804};
+int count_ = 912592;
+
+void WriteMatrix() {
+    kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
+    for (size_t idx = 0; idx < mean_.size(); ++idx) {
+        cmvn_stats(0, idx) = mean_[idx];
+        cmvn_stats(1, idx) = variance_[idx];
+    }
+    cmvn_stats(0, mean_.size()) = count_;
+    kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
+}
+
+int main(int argc, char* argv[]) {
+    gflags::ParseCommandLineFlags(&argc, &argv, false);
+    google::InitGoogleLogging(argv[0]);
+
+    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
+        FLAGS_wav_rspecifier);
+    kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
+    WriteMatrix();
+
+    // test feature linear_spectrogram: wave --> decibel_normalizer --> hanning
+    // window --> linear_spectrogram --> cmvn
+    int32 num_done = 0, num_err = 0;
+    // std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
+    // ppspeech::RawDataCache());
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
+        new ppspeech::RawAudioCache());
+
+    ppspeech::LinearSpectrogramOptions opt;
+    opt.frame_opts.frame_length_ms = 20;
+    opt.frame_opts.frame_shift_ms = 10;
+    ppspeech::DecibelNormalizerOptions db_norm_opt;
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
+        new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
+
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
+        new ppspeech::LinearSpectrogram(opt,
+                                        std::move(base_feature_extractor)));
+
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
+        new ppspeech::CMVN(FLAGS_cmvn_write_path,
+                           std::move(linear_spectrogram)));
+
+    ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
+
+    float streaming_chunk = 0.36;
+    int sample_rate = 16000;
+    int chunk_sample_size = streaming_chunk * sample_rate;
+
+    for (; !wav_reader.Done(); wav_reader.Next()) {
+        std::string utt = wav_reader.Key();
+        const kaldi::WaveData& wave_data = wav_reader.Value();
+
+        int32 this_channel = 0;
+        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+                                                    this_channel);
+        int tot_samples = waveform.Dim();
+        int sample_offset = 0;
+        std::vector<kaldi::Vector<kaldi::BaseFloat>> feats;
+        int feature_rows = 0;
+        while (sample_offset < tot_samples) {
+            int cur_chunk_size =
+                std::min(chunk_sample_size, tot_samples - sample_offset);
+
+            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+            for (int i = 0; i < cur_chunk_size; ++i) {
+                wav_chunk(i) = waveform(sample_offset + i);
+            }
+            kaldi::Vector<kaldi::BaseFloat> features;
+            feature_cache.Accept(wav_chunk);
+            if (cur_chunk_size < chunk_sample_size) {
+                feature_cache.SetFinished();
+            }
+            feature_cache.Read(&features);
+            if (features.Dim() == 0) break;
+
+            feats.push_back(features);
+            sample_offset += cur_chunk_size;
+            feature_rows += features.Dim() / feature_cache.Dim();
+        }
+
+        int cur_idx = 0;
+        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
+                                                 feature_cache.Dim());
+        for (const auto& feat : feats) {
+            int num_rows = feat.Dim() / feature_cache.Dim();
+            for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
+                for (size_t col_idx = 0; col_idx < feature_cache.Dim();
+                     ++col_idx) {
+                    features(cur_idx, col_idx) =
+                        feat(row_idx * feature_cache.Dim() + col_idx);
+                }
+                ++cur_idx;
+            }
+        }
+        feat_writer.Write(utt, features);
+
+        if (num_done % 50 == 0 && num_done != 0)
+            KALDI_VLOG(2) << "Processed " << num_done << " utterances";
+        num_done++;
+    }
+    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+              << " with errors.";
+    return (num_done != 0 ? 0 : 1);
+}
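The `WriteMatrix` helper above stores the precomputed statistics in Kaldi's CMVN layout: row 0 holds the per-dimension sums, row 1 the per-dimension squared sums, and the extra trailing cell of row 0 the frame count. A minimal sketch of how a consumer can turn such a stats matrix into a per-dimension mean and inverse standard deviation (the `StatsToMeanInvStd` helper is hypothetical, not part of this patch):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// stats has 2 rows and dim+1 columns: stats[0][d] = sum over frames,
// stats[1][d] = squared sum, stats[0][dim] = frame count.
void StatsToMeanInvStd(const std::vector<std::vector<double>>& stats,
                       std::vector<double>* mean,
                       std::vector<double>* inv_std) {
    const size_t dim = stats[0].size() - 1;
    const double count = stats[0][dim];
    mean->resize(dim);
    inv_std->resize(dim);
    for (size_t d = 0; d < dim; ++d) {
        (*mean)[d] = stats[0][d] / count;  // E[x]
        // Var(x) = E[x^2] - E[x]^2, floored to keep the sqrt well defined.
        const double var =
            std::max(stats[1][d] / count - (*mean)[d] * (*mean)[d], 1e-20);
        (*inv_std)[d] = 1.0 / std::sqrt(var);
    }
}
```

Applying CMVN is then `(x - mean) * inv_std` per dimension, which is what the `ppspeech::CMVN` stage is expected to do with the `cmvn.ark` written by this example.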
diff --git a/speechx/examples/feat/path.sh b/speechx/examples/feat/path.sh
new file mode 100644
index 00000000000..8ab7ee29918
--- /dev/null
+++ b/speechx/examples/feat/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries required for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project was built successfully."; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/feat
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/feat/run.sh b/speechx/examples/feat/run.sh
new file mode 100755
index 00000000000..bd21bd7f4e1
--- /dev/null
+++ b/speechx/examples/feat/run.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set +x
+set -e
+
+. ./path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+    pushd ${SPEECHX_ROOT}
+    bash build.sh
+    popd
+fi
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+    wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+    tar xzvf paddle_asr_model.tar.gz
+    mv ./paddle_asr_model ../
+    # produce wav scp
+    echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+# 3. run feat
+linear_spectrogram_main \
+    --wav_rspecifier=scp:$model_dir/wav.scp \
+    --feature_wspecifier=ark,t:$feat_wspecifier \
+    --cmvn_write_path=$cmvn
diff --git a/speechx/examples/feat/valgrind.sh b/speechx/examples/feat/valgrind.sh
new file mode 100755
index 00000000000..f8aab63f8c9
--- /dev/null
+++ b/speechx/examples/feat/valgrind.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# this script is for memory check, so please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+    echo "please install valgrind in the speechx tools dir."
+    exit 1
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+    linear_spectrogram_main \
+    --wav_rspecifier=scp:$model_dir/wav.scp \
+    --feature_wspecifier=ark,t:$feat_wspecifier \
+    --cmvn_write_path=$cmvn
+
diff --git a/speechx/examples/nnet/CMakeLists.txt b/speechx/examples/nnet/CMakeLists.txt
new file mode 100644
index 00000000000..20f4008ce53
--- /dev/null
+++ b/speechx/examples/nnet/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
+target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
\ No newline at end of file
diff --git a/speechx/examples/nnet/path.sh b/speechx/examples/nnet/path.sh
new file mode 100644
index 00000000000..f70e70eeaa1
--- /dev/null
+++ b/speechx/examples/nnet/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries required for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project was built successfully."; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/nnet/pp-model-test.cc b/speechx/examples/nnet/pp-model-test.cc
new file mode 100644
index 00000000000..2db354a79a6
--- /dev/null
+++ b/speechx/examples/nnet/pp-model-test.cc
@@ -0,0 +1,193 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <functional>
+#include <gflags/gflags.h>
+#include <iostream>
+#include <iterator>
+#include <numeric>
+#include <thread>
+#include <vector>
+#include "paddle_inference_api.h"
+
+using std::cout;
+using std::endl;
+
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
+
+
+void produce_data(std::vector<std::vector<float>>* data);
+void model_forward_test();
+
+void produce_data(std::vector<std::vector<float>>* data) {
+    int chunk_size = 35;  // chunk_size in frame
+    int col_size = 161;   // feat dim
+    cout << "chunk size: " << chunk_size << endl;
+    cout << "feat dim: " << col_size << endl;
+
+    data->reserve(chunk_size);
+    for (int row = 0; row < chunk_size; ++row) {
+        data->push_back(std::vector<float>());
+        data->back().reserve(col_size);
+        for (int col_idx = 0; col_idx < col_size; ++col_idx) {
+            data->back().push_back(0.201);
+        }
+    }
+}
+
+void model_forward_test() {
+    std::cout << "1. read the data" << std::endl;
+    std::vector<std::vector<float>> feats;
+    produce_data(&feats);
+
+    std::cout << "2. load the model" << std::endl;
+    std::string model_graph = FLAGS_model_path;
+    std::string model_params = FLAGS_param_path;
+    cout << "model path: " << model_graph << endl;
+    cout << "model param path: " << model_params << endl;
+
+    paddle_infer::Config config;
+    config.SetModel(model_graph, model_params);
+    config.SwitchIrOptim(false);
+    cout << "SwitchIrOptim: " << false << endl;
+    config.DisableFCPadding();
+    cout << "DisableFCPadding" << endl;
+    auto predictor = paddle_infer::CreatePredictor(config);
+
+    std::cout << "3. feat shape, row=" << feats.size()
+              << ",col=" << feats[0].size() << std::endl;
+    std::vector<float> pp_input_mat;
+    for (const auto& item : feats) {
+        pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
+    }
+
+    std::cout << "4. feed the data to the model" << std::endl;
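+    // The exported DS2 streaming model takes four inputs: the feature chunk,
+    // its length, and the initial hidden/cell states of the 3-layer,
+    // 1024-wide LSTM (hence the {3, 1, 1024} shapes below); the input names
+    // are queried from the predictor rather than hard-coded.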
+    int row = feats.size();
+    int col = feats[0].size();
+    std::vector<std::string> input_names = predictor->GetInputNames();
+    std::vector<std::string> output_names = predictor->GetOutputNames();
+    for (const auto& name : input_names) {
+        cout << "model input names: " << name << endl;
+    }
+    for (const auto& name : output_names) {
+        cout << "model output names: " << name << endl;
+    }
+
+    // input
+    std::unique_ptr<paddle_infer::Tensor> input_tensor =
+        predictor->GetInputHandle(input_names[0]);
+    std::vector<int> INPUT_SHAPE = {1, row, col};
+    input_tensor->Reshape(INPUT_SHAPE);
+    input_tensor->CopyFromCpu(pp_input_mat.data());
+
+    // input length
+    std::unique_ptr<paddle_infer::Tensor> input_len =
+        predictor->GetInputHandle(input_names[1]);
+    std::vector<int> input_len_size = {1};
+    input_len->Reshape(input_len_size);
+    std::vector<int64_t> audio_len;
+    audio_len.push_back(row);
+    input_len->CopyFromCpu(audio_len.data());
+
+    // state_h
+    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
+        predictor->GetInputHandle(input_names[2]);
+    std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
+    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
+    int chunk_state_h_box_size =
+        std::accumulate(chunk_state_h_box_shape.begin(),
+                        chunk_state_h_box_shape.end(),
+                        1,
+                        std::multiplies<int>());
+    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
+    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
+
+    // state_c
+    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
+        predictor->GetInputHandle(input_names[3]);
+    std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
+    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
+    int chunk_state_c_box_size =
+        std::accumulate(chunk_state_c_box_shape.begin(),
+                        chunk_state_c_box_shape.end(),
+                        1,
+                        std::multiplies<int>());
+    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
+    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
+
+    // run
+    bool success = predictor->Run();
+
+    // state_h out
+    std::unique_ptr<paddle_infer::Tensor> h_out =
+        predictor->GetOutputHandle(output_names[2]);
+    std::vector<int> h_out_shape = h_out->shape();
+    int h_out_size = std::accumulate(
+        h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
+    std::vector<float> h_out_data(h_out_size);
+    h_out->CopyToCpu(h_out_data.data());
+
+    // state_c out
+    std::unique_ptr<paddle_infer::Tensor> c_out =
+        predictor->GetOutputHandle(output_names[3]);
+    std::vector<int> c_out_shape = c_out->shape();
+    int c_out_size = std::accumulate(
+        c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
+    std::vector<float> c_out_data(c_out_size);
+    c_out->CopyToCpu(c_out_data.data());
+
+    // output tensor
+    std::unique_ptr<paddle_infer::Tensor> output_tensor =
+        predictor->GetOutputHandle(output_names[0]);
+    std::vector<int> output_shape = output_tensor->shape();
+    std::vector<float> output_probs;
+    int output_size = std::accumulate(
+        output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+    output_probs.resize(output_size);
+    output_tensor->CopyToCpu(output_probs.data());
+    row = output_shape[1];
+    col = output_shape[2];
+
+    // probs
+    std::vector<std::vector<float>> probs;
+    probs.reserve(row);
+    for (int i = 0; i < row; i++) {
+        probs.push_back(std::vector<float>());
+        probs.back().reserve(col);
+
+        for (int j = 0; j < col; j++) {
+            probs.back().push_back(output_probs[i * col + j]);
+        }
+    }
+
+    std::vector<std::vector<float>> log_feat = probs;
+    std::cout << "probs, row: " << log_feat.size()
+              << " col: " << log_feat[0].size() << std::endl;
+    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
+        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
+             ++col_idx) {
+            std::cout << log_feat[row_idx][col_idx] << " ";
+        }
+        std::cout << std::endl;
+    }
+}
+int main(int argc, char* argv[]) {
+    gflags::ParseCommandLineFlags(&argc, &argv, true);
+    model_forward_test();
+    return 0;
+}
diff --git a/speechx/examples/nnet/run.sh b/speechx/examples/nnet/run.sh
new file mode 100755
index 00000000000..4d67d198842
--- /dev/null
+++ b/speechx/examples/nnet/run.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+    pushd ${SPEECHX_ROOT}
+    bash build.sh
+    popd
+fi
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+    wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+    tar xzvf paddle_asr_model.tar.gz
+    mv ./paddle_asr_model ../
+    # produce wav scp
+    echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+
+# 3. run the nnet forward test
+pp-model-test \
+    --model_path=$model_dir/avg_1.jit.pdmodel \
+    --param_path=$model_dir/avg_1.jit.pdparams
+
diff --git a/speechx/examples/nnet/valgrind.sh b/speechx/examples/nnet/valgrind.sh
new file mode 100755
index 00000000000..2a08c6082f7
--- /dev/null
+++ b/speechx/examples/nnet/valgrind.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# this script is for memory check, so please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+    echo "please install valgrind in the speechx tools dir."
+    exit 1
+fi
+
+model_dir=../paddle_asr_model
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+    pp-model-test \
+    --model_path=$model_dir/avg_1.jit.pdmodel \
+    --param_path=$model_dir/avg_1.jit.pdparams
\ No newline at end of file
diff --git a/speechx/patch/CPPLINT.cfg b/speechx/patch/CPPLINT.cfg
new file mode 100644
index 00000000000..51ff339c184
--- /dev/null
+++ b/speechx/patch/CPPLINT.cfg
@@ -0,0 +1 @@
+exclude_files=.*
diff --git a/speechx/patch/openfst/src/include/fst/flags.h b/speechx/patch/openfst/src/include/fst/flags.h
new file mode 100644
index 00000000000..b5ec8ff7416
--- /dev/null
+++ b/speechx/patch/openfst/src/include/fst/flags.h
@@ -0,0 +1,228 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// See www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// Google-style flag handling declarations and inline definitions.
+
+#ifndef FST_LIB_FLAGS_H_
+#define FST_LIB_FLAGS_H_
+
+#include <cstdlib>
+
+#include <iostream>
+#include <map>
+#include <set>
+#include <sstream>
+#include <string>
+
+#include <fst/types.h>
+#include <fst/lock.h>
+
+#include "gflags/gflags.h"
+#include "glog/logging.h"
+
+using std::string;
+
+// FLAGS USAGE:
+//
+// Definition example:
+//
+//    DEFINE_int32(length, 0, "length");
+//
+// This defines variable FLAGS_length, initialized to 0.
+//
+// Declaration example:
+//
+//    DECLARE_int32(length);
+//
+// SET_FLAGS() can be used to set flags from the command line
+// using, for example, '--length=2'.
+//
+// ShowUsage() can be used to print out command and flag usage.
+
+// #define DECLARE_bool(name) extern bool FLAGS_ ## name
+// #define DECLARE_string(name) extern string FLAGS_ ## name
+// #define DECLARE_int32(name) extern int32 FLAGS_ ## name
+// #define DECLARE_int64(name) extern int64 FLAGS_ ## name
+// #define DECLARE_double(name) extern double FLAGS_ ## name
+
+template <typename T>
+struct FlagDescription {
+    FlagDescription(T *addr, const char *doc, const char *type,
+                    const char *file, const T val)
+        : address(addr),
+          doc_string(doc),
+          type_name(type),
+          file_name(file),
+          default_value(val) {}
+
+    T *address;
+    const char *doc_string;
+    const char *type_name;
+    const char *file_name;
+    const T default_value;
+};
+
+template <typename T>
+class FlagRegister {
+  public:
+    static FlagRegister<T> *GetRegister() {
+        static auto reg = new FlagRegister<T>;
+        return reg;
+    }
+
+    const FlagDescription<T> &GetFlagDescription(const string &name) const {
+        fst::MutexLock l(&flag_lock_);
+        auto it = flag_table_.find(name);
+        return it != flag_table_.end() ? it->second : 0;
+    }
+
+    void SetDescription(const string &name,
+                        const FlagDescription<T> &desc) {
+        fst::MutexLock l(&flag_lock_);
+        flag_table_.insert(make_pair(name, desc));
+    }
+
+    bool SetFlag(const string &val, bool *address) const {
+        if (val == "true" || val == "1" || val.empty()) {
+            *address = true;
+            return true;
+        } else if (val == "false" || val == "0") {
+            *address = false;
+            return true;
+        }
+        else {
+            return false;
+        }
+    }
+
+    bool SetFlag(const string &val, string *address) const {
+        *address = val;
+        return true;
+    }
+
+    bool SetFlag(const string &val, int32 *address) const {
+        char *p = 0;
+        *address = strtol(val.c_str(), &p, 0);
+        return !val.empty() && *p == '\0';
+    }
+
+    bool SetFlag(const string &val, int64 *address) const {
+        char *p = 0;
+        *address = strtoll(val.c_str(), &p, 0);
+        return !val.empty() && *p == '\0';
+    }
+
+    bool SetFlag(const string &val, double *address) const {
+        char *p = 0;
+        *address = strtod(val.c_str(), &p);
+        return !val.empty() && *p == '\0';
+    }
+
+    bool SetFlag(const string &arg, const string &val) const {
+        for (typename std::map<string, FlagDescription<T>>::const_iterator it =
+                 flag_table_.begin();
+             it != flag_table_.end();
+             ++it) {
+            const string &name = it->first;
+            const FlagDescription<T> &desc = it->second;
+            if (arg == name)
+                return SetFlag(val, desc.address);
+        }
+        return false;
+    }
+
+    void GetUsage(std::set<std::pair<string, string>> *usage_set) const {
+        for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) {
+            const string &name = it->first;
+            const FlagDescription<T> &desc = it->second;
+            string usage = "  --" + name;
+            usage += ": type = ";
+            usage += desc.type_name;
+            usage += ", default = ";
+            usage += GetDefault(desc.default_value) + "\n  ";
+            usage += desc.doc_string;
+            usage_set->insert(make_pair(desc.file_name, usage));
+        }
+    }
+
+  private:
+    string GetDefault(bool default_value) const {
+        return default_value ? "true" : "false";
+    }
+
+    string GetDefault(const string &default_value) const {
+        return "\"" + default_value + "\"";
+    }
+
+    template <typename V>
+    string GetDefault(const V &default_value) const {
+        std::ostringstream strm;
+        strm << default_value;
+        return strm.str();
+    }
+
+    mutable fst::Mutex flag_lock_;  // Multithreading lock.
+    std::map<string, FlagDescription<T>> flag_table_;
+};
+
+template <typename T>
+class FlagRegisterer {
+  public:
+    FlagRegisterer(const string &name, const FlagDescription<T> &desc) {
+        auto registr = FlagRegister<T>::GetRegister();
+        registr->SetDescription(name, desc);
+    }
+
+  private:
+    FlagRegisterer(const FlagRegisterer &) = delete;
+    FlagRegisterer &operator=(const FlagRegisterer &) = delete;
+};
+
+
+#define DEFINE_VAR(type, name, value, doc)                                    \
+    type FLAGS_ ## name = value;                                              \
+    static FlagRegisterer<type>                                               \
+        name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \
+                                                               doc,          \
+                                                               #type,        \
+                                                               __FILE__,     \
+                                                               value))
+
+// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc)
+// #define DEFINE_string(name, value, doc) \
+//     DEFINE_VAR(string, name, value, doc)
+// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc)
+// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc)
+// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc)
+
+
+// Temporary directory.
+DECLARE_string(tmpdir);
+
+void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags,
+              const char *src = "");
+
+#define SET_FLAGS(usage, argc, argv, rmflags) \
+    gflags::ParseCommandLineFlags(argc, argv, true)
+// SetFlags(usage, argc, argv, rmflags, __FILE__)
+
+// Deprecated; for backward compatibility.
+inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) {
+    return SetFlags(usage, argc, argv, rmflags);
+}
+
+void ShowUsage(bool long_usage = true);
+
+#endif  // FST_LIB_FLAGS_H_
diff --git a/speechx/patch/openfst/src/include/fst/log.h b/speechx/patch/openfst/src/include/fst/log.h
new file mode 100644
index 00000000000..bf041c58ebf
--- /dev/null
+++ b/speechx/patch/openfst/src/include/fst/log.h
@@ -0,0 +1,82 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// See www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// Google-style logging declarations and inline definitions.
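+//
+// In this patch the local LOG/VLOG and CHECK macro definitions are kept only
+// as comments: gflags/glog (pulled in via fst/flags.h) already provide them,
+// so OpenFst's logging is routed through glog instead.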
+
+#ifndef FST_LIB_LOG_H_
+#define FST_LIB_LOG_H_
+
+#include <cassert>
+#include <iostream>
+#include <string>
+
+#include <fst/types.h>
+#include <fst/flags.h>
+
+using std::string;
+
+DECLARE_int32(v);
+
+class LogMessage {
+  public:
+    LogMessage(const string &type) : fatal_(type == "FATAL") {
+        std::cerr << type << ": ";
+    }
+    ~LogMessage() {
+        std::cerr << std::endl;
+        if (fatal_) exit(1);
+    }
+    std::ostream &stream() { return std::cerr; }
+
+  private:
+    bool fatal_;
+};
+
+// #define LOG(type) LogMessage(#type).stream()
+// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO)
+
+// Checks
+inline void FstCheck(bool x, const char* expr,
+                     const char *file, int line) {
+    if (!x) {
+        LOG(FATAL) << "Check failed: \"" << expr
+                   << "\" file: " << file
+                   << " line: " << line;
+    }
+}
+
+// #define CHECK(x) FstCheck(static_cast<bool>(x), #x, __FILE__, __LINE__)
+// #define CHECK_EQ(x, y) CHECK((x) == (y))
+// #define CHECK_LT(x, y) CHECK((x) < (y))
+// #define CHECK_GT(x, y) CHECK((x) > (y))
+// #define CHECK_LE(x, y) CHECK((x) <= (y))
+// #define CHECK_GE(x, y) CHECK((x) >= (y))
+// #define CHECK_NE(x, y) CHECK((x) != (y))
+
+// Debug checks
+// #define DCHECK(x) assert(x)
+// #define DCHECK_EQ(x, y) DCHECK((x) == (y))
+// #define DCHECK_LT(x, y) DCHECK((x) < (y))
+// #define DCHECK_GT(x, y) DCHECK((x) > (y))
+// #define DCHECK_LE(x, y) DCHECK((x) <= (y))
+// #define DCHECK_GE(x, y) DCHECK((x) >= (y))
+// #define DCHECK_NE(x, y) DCHECK((x) != (y))
+
+
+// Ports
+#define ATTRIBUTE_DEPRECATED __attribute__((deprecated))
+
+#endif  // FST_LIB_LOG_H_
diff --git a/speechx/patch/openfst/src/lib/flags.cc b/speechx/patch/openfst/src/lib/flags.cc
new file mode 100644
index 00000000000..95f7e2e9a56
--- /dev/null
+++ b/speechx/patch/openfst/src/lib/flags.cc
@@ -0,0 +1,166 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Google-style flag handling definitions.
+
+#include <cstring>
+
+#if _MSC_VER
+#include <io.h>
+#include <fcntl.h>
+#endif
+
+#include <fst/flags.h>
+#include <fst/log.h>
+
+static const char *private_tmpdir = getenv("TMPDIR");
+
+// DEFINE_int32(v, 0, "verbosity level");
+// DEFINE_bool(help, false, "show usage information");
+// DEFINE_bool(helpshort, false, "show brief usage information");
+#ifndef _MSC_VER
+DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
+              "temporary directory");
+#else
+DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
+              "temporary directory");
+#endif  // !_MSC_VER
+
+using namespace std;
+
+static string flag_usage;
+static string prog_src;
+
+// Sets prog_src to src.
+static void SetProgSrc(const char *src) {
+    prog_src = src;
+#if _MSC_VER
+    // This common code is invoked by all FST binaries, and only by them. Switch
+    // stdin and stdout into "binary" mode, so that 0x0A won't be translated into
+    // a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
+    // already using ios::binary where binary files are read or written.
+    // Kudos to @daanzu for the suggested fix.
diff --git a/speechx/patch/openfst/src/lib/flags.cc b/speechx/patch/openfst/src/lib/flags.cc
new file mode 100644
index 00000000000..95f7e2e9a56
--- /dev/null
+++ b/speechx/patch/openfst/src/lib/flags.cc
@@ -0,0 +1,166 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Google-style flag handling definitions.
+
+#include <cstring>
+
+#if _MSC_VER
+#include <io.h>
+#include <fcntl.h>
+#endif
+
+#include <fst/compat.h>
+#include <fst/flags.h>
+
+static const char *private_tmpdir = getenv("TMPDIR");
+
+// DEFINE_int32(v, 0, "verbosity level");
+// DEFINE_bool(help, false, "show usage information");
+// DEFINE_bool(helpshort, false, "show brief usage information");
+#ifndef _MSC_VER
+DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
+              "temporary directory");
+#else
+DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
+              "temporary directory");
+#endif  // !_MSC_VER
+
+using namespace std;
+
+static string flag_usage;
+static string prog_src;
+
+// Sets prog_src to src.
+static void SetProgSrc(const char *src) {
+  prog_src = src;
+#if _MSC_VER
+  // This common code is invoked by all FST binaries, and only by them. Switch
+  // stdin and stdout into "binary" mode, so that 0x0A won't be translated into
+  // a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
+  // already using ios::binary where binary files are read or written.
+  // Kudos to @daanzu for the suggested fix.
+  //    https://github.com/kkm000/openfst/issues/20
+  //    https://github.com/kkm000/openfst/pull/23
+  //    https://github.com/kkm000/openfst/pull/32
+  _setmode(_fileno(stdin), O_BINARY);
+  _setmode(_fileno(stdout), O_BINARY);
+#endif
+  // Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
+  // is called in fstx-main.cc, which results in a filename mismatch in
+  // ShowUsageRestrict() below.
+  static constexpr char kMainSuffix[] = "-main.cc";
+  const int prefix_length = prog_src.size() - strlen(kMainSuffix);
+  if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) {
+    prog_src.erase(prefix_length, strlen("-main"));
+  }
+}
+
+void SetFlags(const char *usage, int *argc, char ***argv,
+              bool remove_flags, const char *src) {
+  flag_usage = usage;
+  SetProgSrc(src);
+
+  int index = 1;
+  for (; index < *argc; ++index) {
+    string argval = (*argv)[index];
+    if (argval[0] != '-' || argval == "-") break;
+    while (argval[0] == '-') argval = argval.substr(1);  // Removes initial '-'.
+    string arg = argval;
+    string val = "";
+    // Splits argval (arg=val) into arg and val.
+    auto pos = argval.find("=");
+    if (pos != string::npos) {
+      arg = argval.substr(0, pos);
+      val = argval.substr(pos + 1);
+    }
+    auto bool_register = FlagRegister<bool>::GetRegister();
+    if (bool_register->SetFlag(arg, val))
+      continue;
+    auto string_register = FlagRegister<string>::GetRegister();
+    if (string_register->SetFlag(arg, val))
+      continue;
+    auto int32_register = FlagRegister<int32>::GetRegister();
+    if (int32_register->SetFlag(arg, val))
+      continue;
+    auto int64_register = FlagRegister<int64>::GetRegister();
+    if (int64_register->SetFlag(arg, val))
+      continue;
+    auto double_register = FlagRegister<double>::GetRegister();
+    if (double_register->SetFlag(arg, val))
+      continue;
+    LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index];
+  }
+  if (remove_flags) {
+    for (auto i = 0; i < *argc - index; ++i) {
+      (*argv)[i + 1] = (*argv)[i + index];
+    }
+    *argc -= index - 1;
+  }
+  // if (FLAGS_help) {
+  //   ShowUsage(true);
+  //   exit(1);
+  // }
+  // if (FLAGS_helpshort) {
+  //   ShowUsage(false);
+  //   exit(1);
+  // }
+}
+
+// If flag is defined in file 'src' and 'in_src' true or is not
+// defined in file 'src' and 'in_src' is false, then print usage.
+static void
+ShowUsageRestrict(const std::set<pair<string, string>> &usage_set,
+                  const string &src, bool in_src, bool show_file) {
+  string old_file;
+  bool file_out = false;
+  bool usage_out = false;
+  for (const auto &pair : usage_set) {
+    const auto &file = pair.first;
+    const auto &usage = pair.second;
+    bool match = file == src;
+    if ((match && !in_src) || (!match && in_src)) continue;
+    if (file != old_file) {
+      if (show_file) {
+        if (file_out) cout << "\n";
+        cout << "Flags from: " << file << "\n";
+        file_out = true;
+      }
+      old_file = file;
+    }
+    cout << usage << "\n";
+    usage_out = true;
+  }
+  if (usage_out) cout << "\n";
+}
+
+void ShowUsage(bool long_usage) {
+  std::set<pair<string, string>> usage_set;
+  cout << flag_usage << "\n";
+  auto bool_register = FlagRegister<bool>::GetRegister();
+  bool_register->GetUsage(&usage_set);
+  auto string_register = FlagRegister<string>::GetRegister();
+  string_register->GetUsage(&usage_set);
+  auto int32_register = FlagRegister<int32>::GetRegister();
+  int32_register->GetUsage(&usage_set);
+  auto int64_register = FlagRegister<int64>::GetRegister();
+  int64_register->GetUsage(&usage_set);
+  auto double_register = FlagRegister<double>::GetRegister();
+  double_register->GetUsage(&usage_set);
+  if (!prog_src.empty()) {
+    cout << "PROGRAM FLAGS:\n\n";
+    ShowUsageRestrict(usage_set, prog_src, true, false);
+  }
+  if (!long_usage) return;
+  if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n";
+  ShowUsageRestrict(usage_set, prog_src, false, true);
+}
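With this patch, SET_FLAGS forwards straight to gflags, so openfst-style tools and speechx's own flags go through one parser. A hedged usage sketch (the flag name nnet_path and the usage string are illustrative; it assumes the patched fst/flags.h pulls in gflags, which the macro expansion implies):

#include "fst/flags.h"

DEFINE_string(nnet_path, "final.zip", "nnet model path");  // parsed by gflags

int main(int argc, char* argv[]) {
    const char* usage = "Usage: decoder_main --nnet_path=<path>";
    SET_FLAGS(usage, &argc, &argv, true);  // expands to gflags::ParseCommandLineFlags
    return 0;
}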
diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt
index 71c7eb7cdcb..225abee7cec 100644
--- a/speechx/speechx/CMakeLists.txt
+++ b/speechx/speechx/CMakeLists.txt
@@ -2,13 +2,32 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
 
 project(speechx LANGUAGES CXX)
 
-link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas)
-
 include_directories(
 ${CMAKE_CURRENT_SOURCE_DIR}
 ${CMAKE_CURRENT_SOURCE_DIR}/kaldi
 )
 add_subdirectory(kaldi)
 
-add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
-target_link_libraries(mfcc-test kaldi-mfcc)
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/utils
+)
+add_subdirectory(utils)
+
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/frontend
+)
+add_subdirectory(frontend)
+
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/nnet
+)
+add_subdirectory(nnet)
+
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/decoder
+)
+add_subdirectory(decoder)
\ No newline at end of file
diff --git a/speechx/speechx/base/basic_types.h b/speechx/speechx/base/basic_types.h
index 1966c021e4f..206b7be6754 100644
--- a/speechx/speechx/base/basic_types.h
+++ b/speechx/speechx/base/basic_types.h
@@ -16,45 +16,45 @@
 
 #include "kaldi/base/kaldi-types.h"
 
-#include <limits.h>
+#include <limits>
 
-typedef float       BaseFloat;
-typedef double      double64;
+typedef float BaseFloat;
+typedef double double64;
 
-typedef signed char         int8;
-typedef short               int16;
-typedef int                 int32;
+typedef signed char int8;
+typedef short int16;
+typedef int int32;
 
 #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
-typedef long                int64;
+typedef long int64;
 #else
-typedef long long           int64;
+typedef long long int64;
 #endif
 
-typedef unsigned char      uint8;
-typedef unsigned short     uint16;
-typedef unsigned int       uint32;
+typedef unsigned char uint8;
+typedef unsigned short uint16;
+typedef unsigned int uint32;
 
-if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
+#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
 typedef unsigned long uint64;
 #else
 typedef unsigned long long uint64;
 #endif
 
-typedef signed int         char32;
-
-const uint8  kuint8max  = (( uint8) 0xFF);
-const uint16 kuint16max = ((uint16) 0xFFFF);
-const uint32 kuint32max = ((uint32) 0xFFFFFFFF);
-const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL));
-const  int8  kint8min   = ((  int8) 0x80);
-const  int8  kint8max   = ((  int8) 0x7F);
-const  int16 kint16min  = (( int16) 0x8000);
-const  int16 kint16max  = (( int16) 0x7FFF);
-const  int32 kint32min  = (( int32) 0x80000000);
-const  int32 kint32max  = (( int32) 0x7FFFFFFF);
-const  int64 kint64min  = (( int64) (0x8000000000000000LL));
-const  int64 kint64max  = (( int64) (0x7FFFFFFFFFFFFFFFLL));
-
-const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
-const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
+typedef signed int char32;
+
+const uint8 kuint8max = ((uint8)0xFF);
+const uint16 kuint16max = ((uint16)0xFFFF);
+const uint32 kuint32max = ((uint32)0xFFFFFFFF);
+const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
+const int8 kint8min = ((int8)0x80);
+const int8 kint8max = ((int8)0x7F);
+const int16 kint16min = ((int16)0x8000);
+const int16 kint16max = ((int16)0x7FFF);
+const int32 kint32min = ((int32)0x80000000);
+const int32 kint32max = ((int32)0x7FFFFFFF);
+const int64 kint64min = ((int64)(0x8000000000000000LL));
+const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
+
+const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
+const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h
new file mode 100644
index 00000000000..7502bc5eb8a
--- /dev/null
+++ b/speechx/speechx/base/common.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <condition_variable>
+#include <fstream>
+#include <iostream>
+#include <istream>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <ostream>
+#include <queue>
+#include <sstream>
+#include <stack>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "base/basic_types.h"
+#include "base/flags.h"
+#include "base/log.h"
+#include "base/macros.h"
diff --git a/speechx/speechx/base/flags.h b/speechx/speechx/base/flags.h
new file mode 100644
index 00000000000..41df0d452e5
--- /dev/null
+++ b/speechx/speechx/base/flags.h
@@ -0,0 +1,17 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fst/flags.h"
diff --git a/speechx/speechx/base/log.h b/speechx/speechx/base/log.h
new file mode 100644
index 00000000000..c613b98c34d
--- /dev/null
+++ b/speechx/speechx/base/log.h
@@ -0,0 +1,17 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fst/log.h"
diff --git a/speechx/speechx/base/macros.h b/speechx/speechx/base/macros.h
index c8d254d667b..d7d5a78d102 100644
--- a/speechx/speechx/base/macros.h
+++ b/speechx/speechx/base/macros.h
@@ -16,8 +16,10 @@
 
 namespace ppspeech {
 
+#ifndef DISALLOW_COPY_AND_ASSIGN
 #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
-    TypeName(const TypeName&) = delete; \
-    void operator=(const TypeName&) = delete
+    TypeName(const TypeName&) = delete;    \
+    void operator=(const TypeName&) = delete
+#endif
 
 }  // namespace pp_speech
\ No newline at end of file
diff --git a/speechx/speechx/base/thread_pool.h b/speechx/speechx/base/thread_pool.h
new file mode 100644
index 00000000000..ba895f7147f
--- /dev/null
+++ b/speechx/speechx/base/thread_pool.h
@@ -0,0 +1,110 @@
+// Copyright (c) 2012 Jakob Progsch, Václav Zeman
+
+// This software is provided 'as-is', without any express or implied
+// warranty. In no event will the authors be held liable for any damages
+// arising from the use of this software.
+
+// Permission is granted to anyone to use this software for any purpose,
+// including commercial applications, and to alter it and redistribute it
+// freely, subject to the following restrictions:
+
+// 1. The origin of this software must not be misrepresented; you must not
+// claim that you wrote the original software. If you use this software
+// in a product, an acknowledgment in the product documentation would be
+// appreciated but is not required.
+
+// 2. Altered source versions must be plainly marked as such, and must not be
+// misrepresented as being the original software.
+
+// 3. This notice may not be removed or altered from any source
+// distribution.
+// this code is from https://github.com/progschj/ThreadPool
+
+#ifndef BASE_THREAD_POOL_H
+#define BASE_THREAD_POOL_H
+
+#include <condition_variable>
+#include <functional>
+#include <future>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <stdexcept>
+#include <thread>
+#include <vector>
+
+class ThreadPool {
+  public:
+    ThreadPool(size_t);
+    template <class F, class... Args>
+    auto enqueue(F&& f, Args&&... args)
+        -> std::future<typename std::result_of<F(Args...)>::type>;
+    ~ThreadPool();
+
+  private:
+    // need to keep track of threads so we can join them
+    std::vector<std::thread> workers;
+    // the task queue
+    std::queue<std::function<void()>> tasks;
+
+    // synchronization
+    std::mutex queue_mutex;
+    std::condition_variable condition;
+    bool stop;
+};
+
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
+    for (size_t i = 0; i < threads; ++i)
+        workers.emplace_back([this] {
+            for (;;) {
+                std::function<void()> task;
+
+                {
+                    std::unique_lock<std::mutex> lock(this->queue_mutex);
+                    this->condition.wait(lock, [this] {
+                        return this->stop || !this->tasks.empty();
+                    });
+                    if (this->stop && this->tasks.empty()) return;
+                    task = std::move(this->tasks.front());
+                    this->tasks.pop();
+                }
+
+                task();
+            }
+        });
+}
+
+// add new work item to the pool
+template <class F, class... Args>
+auto ThreadPool::enqueue(F&& f, Args&&... args)
+    -> std::future<typename std::result_of<F(Args...)>::type> {
+    using return_type = typename std::result_of<F(Args...)>::type;
+
+    auto task = std::make_shared<std::packaged_task<return_type()>>(
+        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+
+    std::future<return_type> res = task->get_future();
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+
+        // don't allow enqueueing after stopping the pool
+        if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
+
+        tasks.emplace([task]() { (*task)(); });
+    }
+    condition.notify_one();
+    return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool() {
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+        stop = true;
+    }
+    condition.notify_all();
+    for (std::thread& worker : workers) worker.join();
+}
+
+#endif
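The vendored ThreadPool above is header-only and hands back a std::future per task. A short usage sketch (the worker count and the task body are illustrative):

#include "base/thread_pool.h"

int main() {
    ThreadPool pool(4);  // spawn four worker threads
    // enqueue packages the callable and its arguments into a packaged_task
    auto answer = pool.enqueue([](int x) { return x * 2; }, 21);
    return answer.get() == 42 ? 0 : 1;  // get() blocks until the task has run
}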
diff --git a/speechx/speechx/codelab/README.md b/speechx/speechx/codelab/README.md
deleted file mode 100644
index 95c95db13dc..00000000000
--- a/speechx/speechx/codelab/README.md
+++ /dev/null
@@ -1,4 +0,0 @@
-# codelab
-
-This directory is here for testing some funcitons temporaril.
-
diff --git a/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc b/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc
deleted file mode 100644
index c4367139707..00000000000
--- a/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc
+++ /dev/null
@@ -1,686 +0,0 @@
-// feat/feature-mfcc-test.cc
-
-// Copyright 2009-2011  Karel Vesely;  Petr Motlicek
-
-// See ../../COPYING for clarification regarding multiple authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//  http://www.apache.org/licenses/LICENSE-2.0
-//
-// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-// MERCHANTABLITY OR NON-INFRINGEMENT.
-// See the Apache 2 License for the specific language governing permissions and
-// limitations under the License.
-
-
-#include <iostream>
-
-#include "feat/feature-mfcc.h"
-#include "base/kaldi-math.h"
-#include "matrix/kaldi-matrix-inl.h"
-#include "feat/wave-reader.h"
-
-using namespace kaldi;
-
-
-
-static void UnitTestReadWave() {
-
-  std::cout << "=== UnitTestReadWave() ===\n";
-
-  Vector<BaseFloat> v, v2;
-
-  std::cout << "<<<=== Reading waveform\n";
-
-  {
-    std::ifstream is("test_data/test.wav", std::ios_base::binary);
-    WaveData wave;
-    wave.Read(is);
-    const Matrix<BaseFloat> data(wave.Data());
-    KALDI_ASSERT(data.NumRows() == 1);
-    v.Resize(data.NumCols());
-    v.CopyFromVec(data.Row(0));
-  }
-
-  std::cout << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
-  std::ifstream input(
-    "test_data/test_matlab.ascii"
-  );
-  KALDI_ASSERT(input.good());
-  v2.Read(input, false);
-  input.close();
-
-  std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
-  KALDI_ASSERT(v.Dim() == v2.Dim());
-  for (int32 i = 0; i < v.Dim(); i++) {
-    KALDI_ASSERT(v(i) == v2(i));
-  }
-  std::cout << "<<<=== Comparing done\n";
-
-  // std::cout << "== The Waveform Samples == \n";
-  // std::cout << v;
-
-  std::cout << "Test passed :)\n\n";
-
-}
-
-
-
-/**
- */
-static void UnitTestSimple() {
-  std::cout << "=== UnitTestSimple() ===\n";
-
-  Vector<BaseFloat> v(100000);
-  Matrix<BaseFloat> m;
-
-  // init with noise
-  for (int32 i = 0; i < v.Dim(); i++) {
-    v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2);
-  }
-
-  std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
-  // the parametrization object
-  MfccOptions op;
-  // trying to have same opts as baseline.
-  op.frame_opts.dither = 0.0;
-  op.frame_opts.preemph_coeff = 0.0;
-  op.frame_opts.window_type = "rectangular";
-  op.frame_opts.remove_dc_offset = false;
-  op.frame_opts.round_to_power_of_two = true;
-  op.mel_opts.low_freq = 0.0;
-  op.mel_opts.htk_mode = true;
-  op.htk_compat = true;
-
-  Mfcc mfcc(op);
-  // use default parameters
-
-  // compute mfccs.
-  mfcc.Compute(v, 1.0, &m);
-
-  // possibly dump
-  // std::cout << "== Output features == \n" << m;
-  std::cout << "Test passed :)\n\n";
-}
-
-
-static void UnitTestHTKCompare1() {
-  std::cout << "=== UnitTestHTKCompare1() ===\n";
-
-  std::ifstream is("test_data/test.wav", std::ios_base::binary);
-  WaveData wave;
-  wave.Read(is);
-  KALDI_ASSERT(wave.Data().NumRows() == 1);
-  SubVector<BaseFloat> waveform(wave.Data(), 0);
-
-  // read the HTK features
-  Matrix<BaseFloat> htk_features;
-  {
-    std::ifstream is("test_data/test.wav.fea_htk.1",
-                     std::ios::in | std::ios_base::binary);
-    bool ans = ReadHtk(is, &htk_features, 0);
-    KALDI_ASSERT(ans);
-  }
-
-  // use mfcc with default configuration...
-  MfccOptions op;
-  op.frame_opts.dither = 0.0;
-  op.frame_opts.preemph_coeff = 0.0;
-  op.frame_opts.window_type = "hamming";
-  op.frame_opts.remove_dc_offset = false;
-  op.frame_opts.round_to_power_of_two = true;
-  op.mel_opts.low_freq = 0.0;
-  op.mel_opts.htk_mode = true;
-  op.htk_compat = true;
-  op.use_energy = false;  // C0 not energy.
-
-  Mfcc mfcc(op);
-
-  // calculate kaldi features
-  Matrix<BaseFloat> kaldi_raw_features;
-  mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
-
-  DeltaFeaturesOptions delta_opts;
-  Matrix<BaseFloat> kaldi_features;
-  ComputeDeltas(delta_opts,
-                kaldi_raw_features,
-                &kaldi_features);
-
-  // compare the results
-  bool passed = true;
-  int32 i_old = -1;
-  KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
-  KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
-  // Ignore ends-- we make slightly different choices than
-  // HTK about how to treat the deltas at the ends.
-  for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
-    for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
-      BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
-      if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
-        // print the non-matching data only once per-line
-        if (i_old != i) {
-          std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
-          std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
-          i_old = i;
-        }
-        // print indices of non-matching cells
-        std::cout << "[" << i << ", " << j << "]";
-        passed = false;
-  }}}
-  if (!passed) KALDI_ERR << "Test failed";
-
-  // write the htk features for later inspection
-  HtkHeader header = {
-    kaldi_features.NumRows(),
-    100000,  // 10ms
-    static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
-    021406  // MFCC_D_A_0
-  };
-  {
-    std::ofstream os("tmp.test.wav.fea_kaldi.1",
-                     std::ios::out|std::ios::binary);
-    WriteHtk(os, kaldi_features, header);
-  }
-
-  std::cout << "Test passed :)\n\n";
-
-  unlink("tmp.test.wav.fea_kaldi.1");
-}
-
-
-static void UnitTestHTKCompare2() {
-  std::cout << "=== UnitTestHTKCompare2() ===\n";
-
-  std::ifstream is("test_data/test.wav", std::ios_base::binary);
-  WaveData wave;
-  wave.Read(is);
-  KALDI_ASSERT(wave.Data().NumRows() == 1);
-  SubVector<BaseFloat> waveform(wave.Data(), 0);
-
-  // read the HTK features
-  Matrix<BaseFloat> htk_features;
-  {
-    std::ifstream is("test_data/test.wav.fea_htk.2",
-                     std::ios::in | std::ios_base::binary);
-    bool ans = ReadHtk(is, &htk_features, 0);
-    KALDI_ASSERT(ans);
-  }
-
-  // use mfcc with default configuration...
-  MfccOptions op;
-  op.frame_opts.dither = 0.0;
-  op.frame_opts.preemph_coeff = 0.0;
-  op.frame_opts.window_type = "hamming";
-  op.frame_opts.remove_dc_offset = false;
-  op.frame_opts.round_to_power_of_two = true;
-  op.mel_opts.low_freq = 0.0;
-  op.mel_opts.htk_mode = true;
-  op.htk_compat = true;
-  op.use_energy = true;  // Use energy.
-
-  Mfcc mfcc(op);
-
-  // calculate kaldi features
-  Matrix<BaseFloat> kaldi_raw_features;
-  mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
-
-  DeltaFeaturesOptions delta_opts;
-  Matrix<BaseFloat> kaldi_features;
-  ComputeDeltas(delta_opts,
-                kaldi_raw_features,
-                &kaldi_features);
-
-  // compare the results
-  bool passed = true;
-  int32 i_old = -1;
-  KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
-  KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
-  // Ignore ends-- we make slightly different choices than
-  // HTK about how to treat the deltas at the ends.
-  for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
-    for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
-      BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
-      if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
-        // print the non-matching data only once per-line
-        if (i_old != i) {
-          std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
-          std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
-          i_old = i;
-        }
-        // print indices of non-matching cells
-        std::cout << "[" << i << ", " << j << "]";
-        passed = false;
-  }}}
-  if (!passed) KALDI_ERR << "Test failed";
-
-  // write the htk features for later inspection
-  HtkHeader header = {
-    kaldi_features.NumRows(),
-    100000,  // 10ms
-    static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
-    021406  // MFCC_D_A_0
-  };
-  {
-    std::ofstream os("tmp.test.wav.fea_kaldi.2",
-                     std::ios::out|std::ios::binary);
-    WriteHtk(os, kaldi_features, header);
-  }
-
-  std::cout << "Test passed :)\n\n";
-
-  unlink("tmp.test.wav.fea_kaldi.2");
-}
-
-
-static void UnitTestHTKCompare3() {
-  std::cout << "=== UnitTestHTKCompare3() ===\n";
-
-  std::ifstream is("test_data/test.wav", std::ios_base::binary);
-  WaveData wave;
-  wave.Read(is);
-  KALDI_ASSERT(wave.Data().NumRows() == 1);
-  SubVector<BaseFloat> waveform(wave.Data(), 0);
-
-  // read the HTK features
-  Matrix<BaseFloat> htk_features;
-  {
-    std::ifstream is("test_data/test.wav.fea_htk.3",
-                     std::ios::in | std::ios_base::binary);
-    bool ans = ReadHtk(is, &htk_features, 0);
-    KALDI_ASSERT(ans);
-  }
-
-  // use mfcc with default configuration...
-  MfccOptions op;
-  op.frame_opts.dither = 0.0;
-  op.frame_opts.preemph_coeff = 0.0;
-  op.frame_opts.window_type = "hamming";
-  op.frame_opts.remove_dc_offset = false;
-  op.frame_opts.round_to_power_of_two = true;
-  op.htk_compat = true;
-  op.use_energy = true;  // Use energy.
-  op.mel_opts.low_freq = 20.0;
-  //op.mel_opts.debug_mel = true;
-  op.mel_opts.htk_mode = true;
-
-  Mfcc mfcc(op);
-
-  // calculate kaldi features
-  Matrix<BaseFloat> kaldi_raw_features;
-  mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
-
-  DeltaFeaturesOptions delta_opts;
-  Matrix<BaseFloat> kaldi_features;
-  ComputeDeltas(delta_opts,
-                kaldi_raw_features,
-                &kaldi_features);
-
-  // compare the results
-  bool passed = true;
-  int32 i_old = -1;
-  KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
-  KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
-  // Ignore ends-- we make slightly different choices than
-  // HTK about how to treat the deltas at the ends.
-  for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
-    for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
-      BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
-      if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
-        // print the non-matching data only once per-line
-        if (static_cast<int32>(i_old) != i) {
-          std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
-          std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
-          i_old = i;
-        }
-        // print indices of non-matching cells
-        std::cout << "[" << i << ", " << j << "]";
-        passed = false;
-  }}}
-  if (!passed) KALDI_ERR << "Test failed";
-
-  // write the htk features for later inspection
-  HtkHeader header = {
-    kaldi_features.NumRows(),
-    100000,  // 10ms
-    static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
-    021406  // MFCC_D_A_0
-  };
-  {
-    std::ofstream os("tmp.test.wav.fea_kaldi.3",
-                     std::ios::out|std::ios::binary);
-    WriteHtk(os, kaldi_features, header);
-  }
-
-  std::cout << "Test passed :)\n\n";
-
-  unlink("tmp.test.wav.fea_kaldi.3");
-}
-
-
-static void UnitTestHTKCompare4() {
-  std::cout << "=== UnitTestHTKCompare4() ===\n";
-
-  std::ifstream is("test_data/test.wav", std::ios_base::binary);
-  WaveData wave;
-  wave.Read(is);
-  KALDI_ASSERT(wave.Data().NumRows() == 1);
-  SubVector<BaseFloat> waveform(wave.Data(), 0);
-
-  // read the HTK features
-  Matrix<BaseFloat> htk_features;
-  {
-    std::ifstream is("test_data/test.wav.fea_htk.4",
-                     std::ios::in | std::ios_base::binary);
-    bool ans = ReadHtk(is, &htk_features, 0);
-    KALDI_ASSERT(ans);
-  }
-
-  // use mfcc with default configuration...
-  MfccOptions op;
-  op.frame_opts.dither = 0.0;
-  op.frame_opts.window_type = "hamming";
-  op.frame_opts.remove_dc_offset = false;
-  op.frame_opts.round_to_power_of_two = true;
-  op.mel_opts.low_freq = 0.0;
-  op.htk_compat = true;
-  op.use_energy = true;  // Use energy.
-  op.mel_opts.htk_mode = true;
-
-  Mfcc mfcc(op);
-
-  // calculate kaldi features
-  Matrix<BaseFloat> kaldi_raw_features;
-  mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
-
-  DeltaFeaturesOptions delta_opts;
-  Matrix<BaseFloat> kaldi_features;
-  ComputeDeltas(delta_opts,
-                kaldi_raw_features,
-                &kaldi_features);
-
-  // compare the results
-  bool passed = true;
-  int32 i_old = -1;
-  KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
-  KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
-  // Ignore ends-- we make slightly different choices than
-  // HTK about how to treat the deltas at the ends.
-  for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
-    for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
-      BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
-      if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
-        // print the non-matching data only once per-line
-        if (static_cast<int32>(i_old) != i) {
-          std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
-          std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
-          i_old = i;
-        }
-        // print indices of non-matching cells
-        std::cout << "[" << i << ", " << j << "]";
-        passed = false;
-  }}}
-  if (!passed) KALDI_ERR << "Test failed";
-
-  // write the htk features for later inspection
-  HtkHeader header = {
-    kaldi_features.NumRows(),
-    100000,  // 10ms
-    static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
-    021406  // MFCC_D_A_0
-  };
-  {
-    std::ofstream os("tmp.test.wav.fea_kaldi.4",
-                     std::ios::out|std::ios::binary);
-    WriteHtk(os, kaldi_features, header);
-  }
-
-  std::cout << "Test passed :)\n\n";
-
-  unlink("tmp.test.wav.fea_kaldi.4");
-}
-
-
-static void UnitTestHTKCompare5() {
-  std::cout << "=== UnitTestHTKCompare5() ===\n";
-
-  std::ifstream is("test_data/test.wav", std::ios_base::binary);
-  WaveData wave;
-  wave.Read(is);
-  KALDI_ASSERT(wave.Data().NumRows() == 1);
-  SubVector<BaseFloat> waveform(wave.Data(), 0);
-
-  // read the HTK features
-  Matrix<BaseFloat> htk_features;
-  {
-    std::ifstream is("test_data/test.wav.fea_htk.5",
-                     std::ios::in | std::ios_base::binary);
-    bool ans = ReadHtk(is, &htk_features, 0);
-    KALDI_ASSERT(ans);
-  }
-
-  // use mfcc with default configuration...
-  MfccOptions op;
-  op.frame_opts.dither = 0.0;
-  op.frame_opts.window_type = "hamming";
-  op.frame_opts.remove_dc_offset = false;
-  op.frame_opts.round_to_power_of_two = true;
-  op.htk_compat = true;
-  op.use_energy = true;  // Use energy.
-  op.mel_opts.low_freq = 0.0;
-  op.mel_opts.vtln_low = 100.0;
-  op.mel_opts.vtln_high = 7500.0;
-  op.mel_opts.htk_mode = true;
-
-  BaseFloat vtln_warp = 1.1;  // our approach identical to htk for warp factor >1,
-                              // differs slightly for higher mel bins if warp_factor <0.9
-
-  Mfcc mfcc(op);
-
-  // calculate kaldi features
-  Matrix<BaseFloat> kaldi_raw_features;
-  mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
-
-  DeltaFeaturesOptions delta_opts;
-  Matrix<BaseFloat> kaldi_features;
-  ComputeDeltas(delta_opts,
-                kaldi_raw_features,
-                &kaldi_features);
-
-  // compare the results
-  bool passed = true;
-  int32 i_old = -1;
-  KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
-  KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
-  // Ignore ends-- we make slightly different choices than
-  // HTK about how to treat the deltas at the ends.
-  for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
-    for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
-      BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
-      if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
-        // print the non-matching data only once per-line
-        if (static_cast<int32>(i_old) != i) {
-          std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
-          std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
-          i_old = i;
-        }
-        // print indices of non-matching cells
-        std::cout << "[" << i << ", " << j << "]";
-        passed = false;
-  }}}
-  if (!passed) KALDI_ERR << "Test failed";
-
-  // write the htk features for later inspection
-  HtkHeader header = {
-    kaldi_features.NumRows(),
-    100000,  // 10ms
-    static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
-    021406  // MFCC_D_A_0
-  };
-  {
-    std::ofstream os("tmp.test.wav.fea_kaldi.5",
-                     std::ios::out|std::ios::binary);
-    WriteHtk(os, kaldi_features, header);
-  }
-
-  std::cout << "Test passed :)\n\n";
-
-  unlink("tmp.test.wav.fea_kaldi.5");
-}
-
-static void UnitTestHTKCompare6() {
-  std::cout << "=== UnitTestHTKCompare6() ===\n";
-
-
-  std::ifstream is("test_data/test.wav", std::ios_base::binary);
-  WaveData wave;
-  wave.Read(is);
-  KALDI_ASSERT(wave.Data().NumRows() == 1);
-  SubVector<BaseFloat> waveform(wave.Data(), 0);
-
-  // read the HTK features
-  Matrix<BaseFloat> htk_features;
-  {
-    std::ifstream is("test_data/test.wav.fea_htk.6",
-                     std::ios::in | std::ios_base::binary);
-    bool ans = ReadHtk(is, &htk_features, 0);
-    KALDI_ASSERT(ans);
-  }
-
-  // use mfcc with default configuration...
-  MfccOptions op;
-  op.frame_opts.dither = 0.0;
-  op.frame_opts.preemph_coeff = 0.97;
-  op.frame_opts.window_type = "hamming";
-  op.frame_opts.remove_dc_offset = false;
-  op.frame_opts.round_to_power_of_two = true;
-  op.mel_opts.num_bins = 24;
-  op.mel_opts.low_freq = 125.0;
-  op.mel_opts.high_freq = 7800.0;
-  op.htk_compat = true;
-  op.use_energy = false;  // C0 not energy.
-
-  Mfcc mfcc(op);
-
-  // calculate kaldi features
-  Matrix<BaseFloat> kaldi_raw_features;
-  mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
-
-  DeltaFeaturesOptions delta_opts;
-  Matrix<BaseFloat> kaldi_features;
-  ComputeDeltas(delta_opts,
-                kaldi_raw_features,
-                &kaldi_features);
-
-  // compare the results
-  bool passed = true;
-  int32 i_old = -1;
-  KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
-  KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
-  // Ignore ends-- we make slightly different choices than
-  // HTK about how to treat the deltas at the ends.
-  for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
-    for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
-      BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
-      if ((std::abs(b - a)) > 1.0) {  //<< TOLERANCE TO DIFFERENCES!!!!!
-        // print the non-matching data only once per-line
-        if (static_cast<int32>(i_old) != i) {
-          std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
-          std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
-          i_old = i;
-        }
-        // print indices of non-matching cells
-        std::cout << "[" << i << ", " << j << "]";
-        passed = false;
-  }}}
-  if (!passed) KALDI_ERR << "Test failed";
-
-  // write the htk features for later inspection
-  HtkHeader header = {
-    kaldi_features.NumRows(),
-    100000,  // 10ms
-    static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
-    021406  // MFCC_D_A_0
-  };
-  {
-    std::ofstream os("tmp.test.wav.fea_kaldi.6",
-                     std::ios::out|std::ios::binary);
-    WriteHtk(os, kaldi_features, header);
-  }
-
-  std::cout << "Test passed :)\n\n";
-
-  unlink("tmp.test.wav.fea_kaldi.6");
-}
-
-void UnitTestVtln() {
-  // Test the function VtlnWarpFreq.
-  BaseFloat low_freq = 10, high_freq = 7800,
-      vtln_low_cutoff = 20, vtln_high_cutoff = 7400;
-
-  for (size_t i = 0; i < 100; i++) {
-    BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
-    AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
-                                       low_freq, high_freq, warp_factor,
-                                       freq),
-                freq / warp_factor);
-
-    AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
-                                       low_freq, high_freq, warp_factor,
-                                       low_freq),
-                low_freq);
-    AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
-                                       low_freq, high_freq, warp_factor,
-                                       high_freq),
-                high_freq);
-    BaseFloat freq2 = low_freq + (high_freq-low_freq) * RandUniform(),
-        freq3 = freq2 + (high_freq-freq2) * RandUniform();  // freq3>=freq2
-    BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
-                                          low_freq, high_freq, warp_factor,
-                                          freq2);
-    BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
-                                          low_freq, high_freq, warp_factor,
-                                          freq3);
-    KALDI_ASSERT(w3 >= w2);  // increasing function.
-    BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
-                                              low_freq, high_freq, 1.0,
-                                              freq3);
-    AssertEqual(w3dash, freq3);
-  }
-}
-
-static void UnitTestFeat() {
-  UnitTestVtln();
-  UnitTestReadWave();
-  UnitTestSimple();
-  UnitTestHTKCompare1();
-  UnitTestHTKCompare2();
-  // commenting out this one as it doesn't compare right now I normalized
-  // the way the FFT bins are treated (removed offset of 0.5)... this seems
-  // to relate to the way frequency zero behaves.
-  UnitTestHTKCompare3();
-  UnitTestHTKCompare4();
-  UnitTestHTKCompare5();
-  UnitTestHTKCompare6();
-  std::cout << "Tests succeeded.\n";
-}
-
-
-
-int main() {
-  try {
-    for (int i = 0; i < 5; i++)
-      UnitTestFeat();
-    std::cout << "Tests succeeded.\n";
-    return 0;
-  } catch (const std::exception &e) {
-    std::cerr << e.what();
-    return 1;
-  }
-}
-
-
diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt
index 259261bdf12..7cd281b66f3 100644
--- a/speechx/speechx/decoder/CMakeLists.txt
+++ b/speechx/speechx/decoder/CMakeLists.txt
@@ -1,2 +1,10 @@
-aux_source_directory(. DIR_LIB_SRCS)
-add_library(decoder STATIC ${DIR_LIB_SRCS})
+project(decoder)
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders)
+add_library(decoder STATIC
+    ctc_beam_search_decoder.cc
+    ctc_decoders/decoder_utils.cpp
+    ctc_decoders/path_trie.cpp
+    ctc_decoders/scorer.cpp
+)
+target_link_libraries(decoder PUBLIC kenlm utils fst)
\ No newline at end of file
+ +#include "base/basic_types.h" + +struct DecoderResult { + BaseFloat acoustic_score; + std::vector words_idx; + std::vector> time_stamp; +}; diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc new file mode 100644 index 00000000000..84f1453c0f6 --- /dev/null +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -0,0 +1,314 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/ctc_beam_search_decoder.h" + +#include "base/basic_types.h" +#include "decoder/ctc_decoders/decoder_utils.h" +#include "utils/file_utils.h" + +namespace ppspeech { + +using std::vector; +using FSTMATCH = fst::SortedMatcher; + +CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) + : opts_(opts), + init_ext_scorer_(nullptr), + blank_id_(-1), + space_id_(-1), + num_frame_decoded_(0), + root_(nullptr) { + LOG(INFO) << "dict path: " << opts_.dict_file; + if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { + LOG(INFO) << "load the dict failed"; + } + LOG(INFO) << "read the vocabulary success, dict size: " + << vocabulary_.size(); + + LOG(INFO) << "language model path: " << opts_.lm_path; + init_ext_scorer_ = std::make_shared( + opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); + + blank_id_ = 0; + auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); + + space_id_ = it - vocabulary_.begin(); + // if no space in vocabulary + if ((size_t)space_id_ >= vocabulary_.size()) { + space_id_ = -2; + } +} + +void CTCBeamSearch::Reset() { + // num_frame_decoded_ = 0; + // ResetPrefixes(); + InitDecoder(); +} + +void CTCBeamSearch::InitDecoder() { + num_frame_decoded_ = 0; + // ResetPrefixes(); + prefixes_.clear(); + + root_ = std::make_shared(); + root_->score = root_->log_prob_b_prev = 0.0; + prefixes_.push_back(root_.get()); + if (init_ext_scorer_ != nullptr && + !init_ext_scorer_->is_character_based()) { + auto fst_dict = + static_cast(init_ext_scorer_->dictionary); + fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); + root_->set_dictionary(dict_ptr); + + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root_->set_matcher(matcher); + } +} + +void CTCBeamSearch::Decode( + std::shared_ptr decodable) { + return; +} + +int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; } + +// todo rename, refactor +void CTCBeamSearch::AdvanceDecode( + const std::shared_ptr& decodable) { + while (1) { + vector> likelihood; + vector frame_prob; + bool flag = + decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); + if (flag == false) break; + likelihood.push_back(frame_prob); + AdvanceDecoding(likelihood); + } +} + +void CTCBeamSearch::ResetPrefixes() { + for (size_t i = 0; i < prefixes_.size(); i++) { + if (prefixes_[i] != nullptr) { + delete prefixes_[i]; + prefixes_[i] = nullptr; + } + } + prefixes_.clear(); +} + +int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, + vector& nbest_words) { + kaldi::Timer 
+    kaldi::Timer timer;
+    timer.Reset();
+    AdvanceDecoding(probs);
+    LOG(INFO) << "ctc decoding elapsed time(s) "
+              << static_cast<float>(timer.Elapsed()) / 1000.0f;
+    return 0;
+}
+
+vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
+    return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
+}
+
+string CTCBeamSearch::GetBestPath() {
+    std::vector<std::pair<double, std::string>> result;
+    result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
+    return result[0].second;
+}
+
+string CTCBeamSearch::GetFinalBestPath() {
+    CalculateApproxScore();
+    LMRescore();
+    return GetBestPath();
+}
+
+void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
+    size_t num_time_steps = probs.size();
+    size_t beam_size = opts_.beam_size;
+    double cutoff_prob = opts_.cutoff_prob;
+    size_t cutoff_top_n = opts_.cutoff_top_n;
+
+    vector<vector<double>> probs_seq(probs.size(),
+                                     vector<double>(probs[0].size(), 0));
+
+    int row = probs.size();
+    int col = probs[0].size();
+    for (int i = 0; i < row; i++) {
+        for (int j = 0; j < col; j++) {
+            probs_seq[i][j] = static_cast<double>(probs[i][j]);
+        }
+    }
+
+    for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
+        const auto& prob = probs_seq[time_step];
+
+        float min_cutoff = -NUM_FLT_INF;
+        bool full_beam = false;
+        if (init_ext_scorer_ != nullptr) {
+            size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
+            std::sort(prefixes_.begin(),
+                      prefixes_.begin() + num_prefixes_,
+                      prefix_compare);
+
+            if (num_prefixes_ == 0) {
+                continue;
+            }
+            min_cutoff = prefixes_[num_prefixes_ - 1]->score +
+                         std::log(prob[blank_id_]) -
+                         std::max(0.0, init_ext_scorer_->beta);
+
+            full_beam = (num_prefixes_ == beam_size);
+        }
+
+        vector<std::pair<size_t, float>> log_prob_idx =
+            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
+
+        // loop over chars
+        size_t log_prob_idx_len = log_prob_idx.size();
+        for (size_t index = 0; index < log_prob_idx_len; index++) {
+            SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
+        }
+
+        prefixes_.clear();
+
+        // update log probs
+        root_->iterate_to_vec(prefixes_);
+        // only preserve top beam_size prefixes_
+        if (prefixes_.size() >= beam_size) {
+            std::nth_element(prefixes_.begin(),
+                             prefixes_.begin() + beam_size,
+                             prefixes_.end(),
+                             prefix_compare);
+            for (size_t i = beam_size; i < prefixes_.size(); ++i) {
+                prefixes_[i]->remove();
+            }
+        }  // if
+        num_frame_decoded_++;
+    }  // for probs_seq
+}
+
+int32 CTCBeamSearch::SearchOneChar(
+    const bool& full_beam,
+    const std::pair<size_t, BaseFloat>& log_prob_idx,
+    const BaseFloat& min_cutoff) {
+    size_t beam_size = opts_.beam_size;
+    const auto& c = log_prob_idx.first;
+    const auto& log_prob_c = log_prob_idx.second;
+    size_t prefixes_len = std::min(prefixes_.size(), beam_size);
+
+    for (size_t i = 0; i < prefixes_len; ++i) {
+        auto prefix = prefixes_[i];
+        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
+            break;
+        }
+
+        if (c == blank_id_) {
+            prefix->log_prob_b_cur =
+                log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
+            continue;
+        }
+
+        // repeated character
+        if (c == prefix->character) {
+            // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
+            prefix->log_prob_nb_cur = log_sum_exp(
+                prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
+        }
+
+        // get new prefix
+        auto prefix_new = prefix->get_path_trie(c);
+        if (prefix_new != nullptr) {
+            float log_p = -NUM_FLT_INF;
+            if (c == prefix->character &&
+                prefix->log_prob_b_prev > -NUM_FLT_INF) {
+                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
+                log_p = log_prob_c + prefix->log_prob_b_prev;
+            } else if (c != prefix->character) {
+                // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
+                log_p = log_prob_c + prefix->score;
+            }
+
+            // language model scoring
+            if (init_ext_scorer_ != nullptr &&
+                (c == space_id_ || init_ext_scorer_->is_character_based())) {
+                PathTrie* prefix_to_score = nullptr;
+                // skip scoring the space
+                if (init_ext_scorer_->is_character_based()) {
+                    prefix_to_score = prefix_new;
+                } else {
+                    prefix_to_score = prefix;
+                }
+
+                float score = 0.0;
+                vector<string> ngram;
+                ngram = init_ext_scorer_->make_ngram(prefix_to_score);
+                // lm score: p_{lm}(W)^{\alpha} + \beta
+                score = init_ext_scorer_->get_log_cond_prob(ngram) *
+                        init_ext_scorer_->alpha;
+                log_p += score;
+                log_p += init_ext_scorer_->beta;
+            }
+            // p_{nb}(l;x_{1:t})
+            prefix_new->log_prob_nb_cur =
+                log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
+        }
+    }  // end of loop over prefix
+    return 0;
+}
+
+void CTCBeamSearch::CalculateApproxScore() {
+    size_t beam_size = opts_.beam_size;
+    size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
+    std::sort(
+        prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare);
+
+    // compute aproximate ctc score as the return score, without affecting the
+    // return order of decoding result. To delete when decoder gets stable.
+    for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
+        double approx_ctc = prefixes_[i]->score;
+        if (init_ext_scorer_ != nullptr) {
+            vector<int> output;
+            prefixes_[i]->get_path_vec(output);
+            auto prefix_length = output.size();
+            auto words = init_ext_scorer_->split_labels(output);
+            // remove word insert
+            approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
+            // remove language model weight:
+            approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
+                          init_ext_scorer_->alpha;
+        }
+        prefixes_[i]->approx_ctc = approx_ctc;
+    }
+}
+
+void CTCBeamSearch::LMRescore() {
+    size_t beam_size = opts_.beam_size;
+    if (init_ext_scorer_ != nullptr &&
+        !init_ext_scorer_->is_character_based()) {
+        for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
+            auto prefix = prefixes_[i];
+            if (!prefix->is_empty() && prefix->character != space_id_) {
+                float score = 0.0;
+                vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
+                score = init_ext_scorer_->get_log_cond_prob(ngram) *
+                        init_ext_scorer_->alpha;
+                score += init_ext_scorer_->beta;
+                prefix->score += score;
+            }
+        }
+    }
+}
+
+}  // namespace ppspeech
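AdvanceDecoding and SearchOneChar keep prefix probabilities in the log domain and merge competing paths with log_sum_exp from ctc_decoders. A self-contained sketch of the identity being relied on (standalone helper for illustration only, not the decoder's implementation):

#include <algorithm>
#include <cmath>
#include <cstdio>

// log(exp(a) + exp(b)), evaluated without overflow/underflow
static double LogSumExp(double a, double b) {
    const double hi = std::max(a, b);
    return hi + std::log1p(std::exp(std::min(a, b) - hi));
}

int main() {
    // merging two prefix paths with log-probs -1.2 and -2.3
    std::printf("%.4f\n", LogSumExp(-1.2, -2.3));  // ~= -0.9126
    return 0;
}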
+ +#include "base/common.h" +#include "decoder/ctc_decoders/path_trie.h" +#include "decoder/ctc_decoders/scorer.h" +#include "nnet/decodable-itf.h" +#include "util/parse-options.h" + +#pragma once + +namespace ppspeech { + +struct CTCBeamSearchOptions { + std::string dict_file; + std::string lm_path; + BaseFloat alpha; + BaseFloat beta; + BaseFloat cutoff_prob; + int beam_size; + int cutoff_top_n; + int num_proc_bsearch; + CTCBeamSearchOptions() + : dict_file("vocab.txt"), + lm_path("lm.klm"), + alpha(1.9f), + beta(5.0), + beam_size(300), + cutoff_prob(0.99f), + cutoff_top_n(40), + num_proc_bsearch(0) {} + + void Register(kaldi::OptionsItf* opts) { + opts->Register("dict", &dict_file, "dict file "); + opts->Register("lm-path", &lm_path, "language model file"); + opts->Register("alpha", &alpha, "alpha"); + opts->Register("beta", &beta, "beta"); + opts->Register( + "beam-size", &beam_size, "beam size for beam search method"); + opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs"); + opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n"); + opts->Register( + "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); + } +}; + +class CTCBeamSearch { + public: + explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); + ~CTCBeamSearch() {} + void InitDecoder(); + void Decode(std::shared_ptr decodable); + std::string GetBestPath(); + std::vector> GetNBestPath(); + std::string GetFinalBestPath(); + int NumFrameDecoded(); + int DecodeLikelihoods(const std::vector>& probs, + std::vector& nbest_words); + void AdvanceDecode( + const std::shared_ptr& decodable); + void Reset(); + + private: + void ResetPrefixes(); + int32 SearchOneChar(const bool& full_beam, + const std::pair& log_prob_idx, + const BaseFloat& min_cutoff); + void CalculateApproxScore(); + void LMRescore(); + void AdvanceDecoding(const std::vector>& probs); + + CTCBeamSearchOptions opts_; + std::shared_ptr init_ext_scorer_; // todo separate later + std::vector vocabulary_; // todo remove later + size_t blank_id_; + int space_id_; + std::shared_ptr root_; + std::vector prefixes_; + int num_frame_decoded_; + DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); +}; + +} // namespace basr \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_decoders b/speechx/speechx/decoder/ctc_decoders new file mode 120000 index 00000000000..b280de09681 --- /dev/null +++ b/speechx/speechx/decoder/ctc_decoders @@ -0,0 +1 @@ +../../../third_party/ctc_decoders \ No newline at end of file diff --git a/speechx/speechx/frontend/CMakeLists.txt b/speechx/speechx/frontend/CMakeLists.txt index e69de29bb2d..44ca52cdc08 100644 --- a/speechx/speechx/frontend/CMakeLists.txt +++ b/speechx/speechx/frontend/CMakeLists.txt @@ -0,0 +1,10 @@ +project(frontend) + +add_library(frontend STATIC + normalizer.cc + linear_spectrogram.cc + raw_audio.cc + feature_cache.cc +) + +target_link_libraries(frontend PUBLIC kaldi-matrix) diff --git a/speechx/speechx/frontend/fbank.h b/speechx/speechx/frontend/fbank.h new file mode 100644 index 00000000000..7d9cf422125 --- /dev/null +++ b/speechx/speechx/frontend/fbank.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/speechx/speechx/decoder/ctc_decoders b/speechx/speechx/decoder/ctc_decoders
new file mode 120000
index 00000000000..b280de09681
--- /dev/null
+++ b/speechx/speechx/decoder/ctc_decoders
@@ -0,0 +1 @@
+../../../third_party/ctc_decoders
\ No newline at end of file
diff --git a/speechx/speechx/frontend/CMakeLists.txt b/speechx/speechx/frontend/CMakeLists.txt
index e69de29bb2d..44ca52cdc08 100644
--- a/speechx/speechx/frontend/CMakeLists.txt
+++ b/speechx/speechx/frontend/CMakeLists.txt
@@ -0,0 +1,10 @@
+project(frontend)
+
+add_library(frontend STATIC
+    normalizer.cc
+    linear_spectrogram.cc
+    raw_audio.cc
+    feature_cache.cc
+)
+
+target_link_libraries(frontend PUBLIC kaldi-matrix)
diff --git a/speechx/speechx/frontend/fbank.h b/speechx/speechx/frontend/fbank.h
new file mode 100644
index 00000000000..7d9cf422125
--- /dev/null
+++ b/speechx/speechx/frontend/fbank.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// wrap the fbank feat of kaldi, todo (SmileGoat)
+
+#include "kaldi/feat/feature-mfcc.h"
+
+#include "kaldi/matrix/kaldi-vector.h"
+
+namespace ppspeech {
+
+class FbankExtractor : FeatureExtractorInterface {
+  public:
+    explicit FbankExtractor(const FbankOptions& opts,
+                            std::shared_ptr<FeatureExtractorInterface> pre_extractor);
+    virtual void AcceptWaveform(
+        const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
+    virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
+    virtual size_t Dim() const = 0;
+
+  private:
+    bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
+                 kaldi::Vector<kaldi::BaseFloat>* feat) const;
+};
+
+}  // namespace ppspeech
\ No newline at end of file
+ +#include "frontend/feature_cache.h" + +namespace ppspeech { + +using kaldi::Vector; +using kaldi::VectorBase; +using kaldi::BaseFloat; +using std::vector; +using kaldi::SubVector; +using std::unique_ptr; + +FeatureCache::FeatureCache( + int max_size, unique_ptr base_extractor) { + max_size_ = max_size; + base_extractor_ = std::move(base_extractor); +} + +void FeatureCache::Accept(const kaldi::VectorBase& inputs) { + base_extractor_->Accept(inputs); + // feed current data + bool result = false; + do { + result = Compute(); + } while (result); +} + +// pop feature chunk +bool FeatureCache::Read(kaldi::Vector* feats) { + kaldi::Timer timer; + std::unique_lock lock(mutex_); + while (cache_.empty() && base_extractor_->IsFinished() == false) { + ready_read_condition_.wait(lock); + BaseFloat elapsed = timer.Elapsed() * 1000; + // todo replace 1.0 with timeout_ + if (elapsed > 1.0) { + return false; + } + usleep(1000); // sleep 1 ms + } + if (cache_.empty()) return false; + feats->Resize(cache_.front().Dim()); + feats->CopyFromVec(cache_.front()); + cache_.pop(); + ready_feed_condition_.notify_one(); + return true; +} + +// read all data from base_feature_extractor_ into cache_ +bool FeatureCache::Compute() { + // compute and feed + Vector feature_chunk; + bool result = base_extractor_->Read(&feature_chunk); + std::unique_lock lock(mutex_); + while (cache_.size() >= max_size_) { + ready_feed_condition_.wait(lock); + } + if (feature_chunk.Dim() != 0) { + cache_.push(feature_chunk); + } + ready_read_condition_.notify_one(); + return result; +} + +void Reset() { + // std::lock_guard lock(mutex_); + return; +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/feature_cache.h b/speechx/speechx/frontend/feature_cache.h new file mode 100644 index 00000000000..e52d8b2981a --- /dev/null +++ b/speechx/speechx/frontend/feature_cache.h @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/speechx/speechx/frontend/feature_cache.h b/speechx/speechx/frontend/feature_cache.h
new file mode 100644
index 00000000000..e52d8b2981a
--- /dev/null
+++ b/speechx/speechx/frontend/feature_cache.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "base/common.h"
+#include "frontend/feature_extractor_interface.h"
+
+namespace ppspeech {
+
+class FeatureCache : public FeatureExtractorInterface {
+  public:
+    explicit FeatureCache(
+        int32 max_size = kint16max,
+        std::unique_ptr<FeatureExtractorInterface> base_extractor = NULL);
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
+    // feats dim = num_frames * feature_dim
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
+    // feature cache only cache feature which from base extractor
+    virtual size_t Dim() const { return base_extractor_->Dim(); }
+    virtual void SetFinished() {
+        base_extractor_->SetFinished();
+        // read the last chunk data
+        Compute();
+    }
+    virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
+    virtual void Reset() {
+        base_extractor_->Reset();
+        while (!cache_.empty()) {
+            cache_.pop();
+        }
+    }
+
+  private:
+    bool Compute();
+
+    std::mutex mutex_;
+    size_t max_size_;
+    std::queue<kaldi::Vector<kaldi::BaseFloat>> cache_;
+    std::unique_ptr<FeatureExtractorInterface> base_extractor_;
+    std::condition_variable ready_feed_condition_;
+    std::condition_variable ready_read_condition_;
+    // DISALLOW_COPY_AND_ASSGIN(FeatureCache);
+};
+
+}  // namespace ppspeech
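FeatureCache exposes the same interface it wraps, so pipeline stages compose by ownership. A wiring sketch (the upstream source stands in for whichever extractor feeds the cache in a real pipeline; the chunk limit is illustrative):

#include "frontend/feature_cache.h"

// wrap any upstream extractor with a bounded cache of 100 feature chunks
std::unique_ptr<ppspeech::FeatureExtractorInterface> MakeCached(
    std::unique_ptr<ppspeech::FeatureExtractorInterface> source) {
    return std::unique_ptr<ppspeech::FeatureExtractorInterface>(
        new ppspeech::FeatureCache(100, std::move(source)));
}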
diff --git a/speechx/speechx/frontend/feature_extractor_interface.h b/speechx/speechx/frontend/feature_extractor_interface.h
new file mode 100644
index 00000000000..3668fbda769
--- /dev/null
+++ b/speechx/speechx/frontend/feature_extractor_interface.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "base/basic_types.h"
+#include "kaldi/matrix/kaldi-vector.h"
+
+namespace ppspeech {
+
+class FeatureExtractorInterface {
+  public:
+    // accept input data: either features or raw waveform,
+    // depending on the concrete extractor
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;
+    // get the processed result;
+    // the length of output = num_frames * feature_dim,
+    // i.e. the feature Matrix is squashed into a flat Vector
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* outputs) = 0;
+    // the dim of a single frame of the feature
+    virtual size_t Dim() const = 0;
+    virtual void SetFinished() = 0;
+    virtual bool IsFinished() const = 0;
+    virtual void Reset() = 0;
+};
+
+}  // namespace ppspeech
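Every stage in the frontend implements this interface and owns the previous stage, so pipelines are built by composition. A minimal, hypothetical pass-through stage (not in the patch) showing the contract; the signatures follow the interface above:

```cpp
// A do-nothing stage: Accept forwards input downstream, Read surfaces
// whatever the wrapped extractor yields. Useful as a template for new stages.
#include "frontend/feature_extractor_interface.h"

namespace ppspeech {

class Identity : public FeatureExtractorInterface {
  public:
    explicit Identity(std::unique_ptr<FeatureExtractorInterface> base)
        : base_(std::move(base)) {}
    void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override {
        base_->Accept(inputs);
    }
    bool Read(kaldi::Vector<kaldi::BaseFloat>* outputs) override {
        return base_->Read(outputs);
    }
    size_t Dim() const override { return base_->Dim(); }
    void SetFinished() override { base_->SetFinished(); }
    bool IsFinished() const override { return base_->IsFinished(); }
    void Reset() override { base_->Reset(); }

  private:
    std::unique_ptr<FeatureExtractorInterface> base_;
};

}  // namespace ppspeech
```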
diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc
new file mode 100644
index 00000000000..41bc8743a93
--- /dev/null
+++ b/speechx/speechx/frontend/linear_spectrogram.cc
@@ -0,0 +1,156 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "frontend/linear_spectrogram.h"
+#include "kaldi/base/kaldi-math.h"
+#include "kaldi/matrix/matrix-functions.h"
+
+namespace ppspeech {
+
+using kaldi::int32;
+using kaldi::BaseFloat;
+using kaldi::Vector;
+using kaldi::VectorBase;
+using kaldi::Matrix;
+using std::vector;
+
+LinearSpectrogram::LinearSpectrogram(
+    const LinearSpectrogramOptions& opts,
+    std::unique_ptr<FeatureExtractorInterface> base_extractor) {
+    opts_ = opts;
+    base_extractor_ = std::move(base_extractor);
+    int32 window_size = opts.frame_opts.WindowSize();
+    int32 window_shift = opts.frame_opts.WindowShift();
+    fft_points_ = window_size;
+    chunk_sample_size_ =
+        static_cast<int32>(opts.streaming_chunk * opts.frame_opts.samp_freq);
+    hanning_window_.resize(window_size);
+
+    double a = M_2PI / (window_size - 1);
+    hanning_window_energy_ = 0;
+    for (int i = 0; i < window_size; ++i) {
+        hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
+        hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
+    }
+
+    dim_ = fft_points_ / 2 + 1;  // the dimension covers 0 .. Fs/2 Hz
+}
+
+void LinearSpectrogram::Accept(const VectorBase<BaseFloat>& inputs) {
+    base_extractor_->Accept(inputs);
+}
+
+bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
+    Vector<BaseFloat> input_feats(chunk_sample_size_);
+    bool flag = base_extractor_->Read(&input_feats);
+    if (flag == false || input_feats.Dim() == 0) return false;
+
+    vector<BaseFloat> input_feats_vec(input_feats.Dim());
+    std::memcpy(input_feats_vec.data(),
+                input_feats.Data(),
+                input_feats.Dim() * sizeof(BaseFloat));
+    vector<vector<BaseFloat>> result;
+    Compute(input_feats_vec, result);
+    int32 feat_size = 0;
+    if (result.size() != 0) {
+        feat_size = result.size() * result[0].size();
+    }
+    feats->Resize(feat_size);
+    // todo refactor (SmileGoat)
+    for (size_t idx = 0; idx < feat_size; ++idx) {
+        (*feats)(idx) = result[idx / dim_][idx % dim_];
+    }
+    return true;
+}
+
+void LinearSpectrogram::Hanning(vector<BaseFloat>* data) const {
+    CHECK_GE(data->size(), hanning_window_.size());
+
+    for (size_t i = 0; i < hanning_window_.size(); ++i) {
+        data->at(i) *= hanning_window_[i];
+    }
+}
+
+bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
+                                 vector<BaseFloat>* real,
+                                 vector<BaseFloat>* img) const {
+    Vector<BaseFloat> v_tmp;
+    v_tmp.Resize(v->size());
+    std::memcpy(v_tmp.Data(), v->data(), sizeof(BaseFloat) * (v->size()));
+    RealFft(&v_tmp, true);
+    v->resize(v_tmp.Dim());
+    std::memcpy(v->data(), v_tmp.Data(), sizeof(BaseFloat) * (v->size()));
+
+    // unpack kaldi's packed real-FFT layout: bin 0 and the Nyquist bin are
+    // purely real and stored in elements 0 and 1; the rest interleave
+    // (real, imag) pairs
+    real->push_back(v->at(0));
+    img->push_back(0);
+    for (int i = 1; i < v->size() / 2; i++) {
+        real->push_back(v->at(2 * i));
+        img->push_back(v->at(2 * i + 1));
+    }
+    real->push_back(v->at(1));
+    img->push_back(0);
+
+    return true;
+}
+
+// Compute the spectrogram feature.
+// todo: refactor later (SmileGoat)
+bool LinearSpectrogram::Compute(const vector<BaseFloat>& waves,
+                                vector<vector<BaseFloat>>& feats) {
+    int num_samples = waves.size();
+    const int& frame_length = opts_.frame_opts.WindowSize();
+    const int& sample_rate = opts_.frame_opts.samp_freq;
+    const int& frame_shift = opts_.frame_opts.WindowShift();
+    const int& fft_points = fft_points_;
+    const float scale = hanning_window_energy_ * sample_rate;
+
+    if (num_samples < frame_length) {
+        return true;
+    }
+
+    int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
+    feats.resize(num_frames);
+    vector<float> fft_real((fft_points_ / 2 + 1), 0);
+    vector<float> fft_img((fft_points_ / 2 + 1), 0);
+    vector<BaseFloat> v(frame_length, 0);
+    vector<BaseFloat> power((fft_points / 2 + 1));
+
+    for (int i = 0; i < num_frames; ++i) {
+        vector<BaseFloat> data(waves.data() + i * frame_shift,
+                               waves.data() + i * frame_shift + frame_length);
+        Hanning(&data);
+        fft_img.clear();
+        fft_real.clear();
+        v.assign(data.begin(), data.end());
+        NumpyFft(&v, &fft_real, &fft_img);
+
+        feats[i].resize(fft_points / 2 + 1);  // the last dim covers Fs/2 Hz
+        for (int j = 0; j < (fft_points / 2 + 1); ++j) {
+            power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
+            feats[i][j] = power[j];
+
+            if (j == 0 || j == feats[0].size() - 1) {
+                feats[i][j] /= scale;
+            } else {
+                feats[i][j] *= (2.0 / scale);
+            }
+
+            // log with eps = 1e-14 added to avoid log(0)
+            feats[i][j] = std::log(feats[i][j] + 1e-14);
+        }
+    }
+    return true;
+}
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h
new file mode 100644
index 00000000000..ffdfbbe9281
--- /dev/null
+++ b/speechx/speechx/frontend/linear_spectrogram.h
@@ -0,0 +1,68 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#pragma once
+
+#include "base/common.h"
+#include "frontend/feature_extractor_interface.h"
+#include "kaldi/feat/feature-window.h"
+
+namespace ppspeech {
+
+struct LinearSpectrogramOptions {
+    kaldi::FrameExtractionOptions frame_opts;
+    kaldi::BaseFloat streaming_chunk;
+    LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
+
+    void Register(kaldi::OptionsItf* opts) {
+        opts->Register(
+            "streaming-chunk", &streaming_chunk, "streaming chunk size");
+        frame_opts.Register(opts);
+    }
+};
+
+class LinearSpectrogram : public FeatureExtractorInterface {
+  public:
+    explicit LinearSpectrogram(
+        const LinearSpectrogramOptions& opts,
+        std::unique_ptr<FeatureExtractorInterface> base_extractor);
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
+    // dim_ is the dim of a single frame of the feature
+    virtual size_t Dim() const { return dim_; }
+    virtual void SetFinished() { base_extractor_->SetFinished(); }
+    virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
+    virtual void Reset() { base_extractor_->Reset(); }
+
+  private:
+    void Hanning(std::vector<kaldi::BaseFloat>* data) const;
+    bool Compute(const std::vector<kaldi::BaseFloat>& waves,
+                 std::vector<std::vector<kaldi::BaseFloat>>& feats);
+    bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
+                  std::vector<kaldi::BaseFloat>* real,
+                  std::vector<kaldi::BaseFloat>* img) const;
+
+    kaldi::int32 fft_points_;
+    size_t dim_;
+    std::vector<kaldi::BaseFloat> hanning_window_;
+    kaldi::BaseFloat hanning_window_energy_;
+    LinearSpectrogramOptions opts_;
+    std::unique_ptr<FeatureExtractorInterface> base_extractor_;
+    int chunk_sample_size_;
+    DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
+};
+
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/frontend/mfcc.h b/speechx/speechx/frontend/mfcc.h
new file mode 100644
index 00000000000..aa369655e53
--- /dev/null
+++ b/speechx/speechx/frontend/mfcc.h
@@ -0,0 +1,16 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// wraps the kaldi mfcc feature; todo (SmileGoat)
+#include "kaldi/feat/feature-mfcc.h"
\ No newline at end of file
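A quick standalone check (not from the patch) of the framing arithmetic `Compute` relies on, using Kaldi's default 25 ms window and 10 ms shift at 16 kHz; the concrete numbers here are illustrative:

```cpp
// 0.36 s of 16 kHz audio (the default streaming_chunk) framed with a
// 400-sample window and 160-sample shift yields 34 frames of dim 201.
#include <cstdio>

int main() {
    const int sample_rate = 16000;
    const int frame_length = sample_rate * 25 / 1000;  // WindowSize() = 400
    const int frame_shift = sample_rate * 10 / 1000;   // WindowShift() = 160
    const int num_samples = static_cast<int>(0.36 * sample_rate);  // 5760
    const int num_frames = 1 + (num_samples - frame_length) / frame_shift;
    const int dim = frame_length / 2 + 1;  // fft_points_ = window size
    std::printf("frames=%d dim=%d\n", num_frames, dim);  // frames=34 dim=201
    return 0;
}
```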
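The `streaming_chunk` option in `LinearSpectrogramOptions` above controls how much audio each `Read` call pulls from the base extractor. A sketch of exposing those options on a command line through a Kaldi-style parser; the `kaldi/util/parse-options.h` include path is an assumption about this tree:

```cpp
#include <cstdio>

#include "frontend/linear_spectrogram.h"
#include "kaldi/util/parse-options.h"  // assumed include path

int main(int argc, char* argv[]) {
    kaldi::ParseOptions po("linear-spectrogram-demo");
    ppspeech::LinearSpectrogramOptions opts;
    opts.Register(&po);  // exposes --streaming-chunk plus the frame options
    po.Read(argc, argv);

    // the LinearSpectrogram constructor converts seconds to samples this way
    int chunk_samples =
        static_cast<int>(opts.streaming_chunk * opts.frame_opts.samp_freq);
    std::printf("chunk = %d samples\n", chunk_samples);  // 5760 by default
    return 0;
}
```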
diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc
new file mode 100644
index 00000000000..1adddb401b8
--- /dev/null
+++ b/speechx/speechx/frontend/normalizer.cc
@@ -0,0 +1,188 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "frontend/normalizer.h"
+#include "kaldi/feat/cmvn.h"
+#include "kaldi/util/kaldi-io.h"
+
+namespace ppspeech {
+
+using kaldi::Vector;
+using kaldi::VectorBase;
+using kaldi::BaseFloat;
+using std::vector;
+using kaldi::SubVector;
+using std::unique_ptr;
+
+DecibelNormalizer::DecibelNormalizer(
+    const DecibelNormalizerOptions& opts,
+    std::unique_ptr<FeatureExtractorInterface> base_extractor) {
+    base_extractor_ = std::move(base_extractor);
+    opts_ = opts;
+    dim_ = 1;
+}
+
+void DecibelNormalizer::Accept(const kaldi::VectorBase<BaseFloat>& waves) {
+    base_extractor_->Accept(waves);
+}
+
+bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
+    if (base_extractor_->Read(waves) == false || waves->Dim() == 0) {
+        return false;
+    }
+    Compute(waves);
+    return true;
+}
+
+bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
+    // calculate the rms level in db
+    BaseFloat rms_db = 0.0;
+    BaseFloat mean_square = 0.0;
+    BaseFloat gain = 0.0;
+    BaseFloat wave_float_normalization = 1.0f / (std::pow(2, 16 - 1));
+
+    vector<BaseFloat> samples;
+    samples.resize(waves->Dim());
+    for (size_t i = 0; i < samples.size(); ++i) {
+        samples[i] = (*waves)(i);
+    }
+
+    // sum of squares
+    for (auto& d : samples) {
+        if (opts_.convert_int_float) {
+            d = d * wave_float_normalization;
+        }
+        mean_square += d * d;
+    }
+
+    // mean
+    mean_square /= samples.size();
+    rms_db = 10 * std::log10(mean_square);
+    gain = opts_.target_db - rms_db;
+
+    if (gain > opts_.max_gain_db) {
+        LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db
+                   << "dB, because the required gain exceeds max_gain_db = "
+                   << opts_.max_gain_db << "dB.";
+        return false;
+    }
+
+    // Note that this is an in-place transformation.
+    for (auto& item : samples) {
+        // python: item *= 10.0 ** (gain / 20.0)
+        item *= std::pow(10.0, gain / 20.0);
+    }
+
+    std::memcpy(
+        waves->Data(), samples.data(), sizeof(BaseFloat) * samples.size());
+    return true;
+}
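The gain computation above reduces to measuring the segment's RMS level in dB and scaling toward `target_db`. A self-contained toy version with made-up samples, not part of the patch:

```cpp
// rms_db = 10 * log10(mean(x^2)); gain = target_db - rms_db;
// each sample is then multiplied by 10^(gain / 20).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> x = {0.1f, -0.2f, 0.05f, 0.15f};
    double mean_square = 0.0;
    for (float v : x) mean_square += v * v;
    mean_square /= x.size();
    double rms_db = 10.0 * std::log10(mean_square);  // about -17.3 dB here
    double target_db = -20.0;
    double gain = target_db - rms_db;                // negative: attenuate
    for (float& v : x) v *= std::pow(10.0, gain / 20.0);
    std::printf("rms_db=%.2f gain=%.2f\n", rms_db, gain);
    return 0;
}
```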
+
+CMVN::CMVN(std::string cmvn_file,
+           unique_ptr<FeatureExtractorInterface> base_extractor)
+    : var_norm_(true) {
+    base_extractor_ = std::move(base_extractor);
+    bool binary;
+    kaldi::Input ki(cmvn_file, &binary);
+    stats_.Read(ki.Stream(), binary);
+    dim_ = stats_.NumCols() - 1;
+}
+
+void CMVN::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
+    base_extractor_->Accept(inputs);
+    return;
+}
+
+bool CMVN::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
+    if (base_extractor_->Read(feats) == false) {
+        return false;
+    }
+    Compute(feats);
+    return true;
+}
+
+// feats contains num_frames frames of features.
+void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
+    KALDI_ASSERT(feats != NULL);
+    int32 dim = stats_.NumCols() - 1;
+    if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
+        feats->Dim() % dim != 0) {
+        KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
+                  << stats_.NumCols() << ", feats " << feats->Dim();
+    }
+    if (stats_.NumRows() == 1 && var_norm_) {
+        KALDI_ERR
+            << "You requested variance normalization but no variance stats_ "
+            << "are supplied.";
+    }
+
+    double count = stats_(0, dim);
+    // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
+    // computing an offset and representing it as stats_, we use a count of
+    // one.
+    if (count < 1.0)
+        KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
+                     "normalization: "
+                  << "count = " << count;
+
+    if (!var_norm_) {
+        Vector<BaseFloat> offset(feats->Dim());
+        SubVector<double> mean_stats(stats_.RowData(0), dim);
+        Vector<double> mean_stats_apply(feats->Dim());
+        // tile mean_stats into mean_stats_apply so that its dim matches the
+        // feats dim; the dim of feats = dim * num_frames.
+        for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
+            SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
+                                        dim);
+            stats_tmp.CopyFromVec(mean_stats);
+        }
+        offset.AddVec(-1.0 / count, mean_stats_apply);
+        feats->AddVec(1.0, offset);
+        return;
+    }
+    // norm(0, d) = mean offset;
+    // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
+    kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
+    for (int32 d = 0; d < dim; d++) {
+        double mean, offset, scale;
+        mean = stats_(0, d) / count;
+        double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
+        if (var < floor) {
+            KALDI_WARN << "Flooring cepstral variance from " << var << " to "
+                       << floor;
+            var = floor;
+        }
+        scale = 1.0 / sqrt(var);
+        if (scale != scale || 1 / scale == 0.0)
+            KALDI_ERR
+                << "NaN or infinity in cepstral mean/variance computation";
+        offset = -(mean * scale);
+        for (int32 d_skip = d; d_skip < feats->Dim();) {
+            norm(0, d_skip) = offset;
+            norm(1, d_skip) = scale;
+            d_skip = d_skip + dim;
+        }
+    }
+    // Apply the normalization.
+    feats->MulElements(norm.Row(1));
+    feats->AddVec(1.0, norm.Row(0));
+}
+
+void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
+    ApplyCmvn(stats_, var_norm_, feats);
+}
+
+}  // namespace ppspeech
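`CMVN::Compute` expects Kaldi-style stats: a 2 x (dim + 1) matrix whose first row holds per-dimension sums with the frame count in the last column, and whose second row holds per-dimension sums of squares. A toy illustration of the arithmetic, with invented numbers:

```cpp
// Normalization per dimension: x' = (x - mean) / sqrt(var), where
// mean = sum / count and var = sum_sq / count - mean^2.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int dim = 2;
    // stats is 2 x (dim + 1); last column of row 0 is the frame count
    double stats[2][dim + 1] = {{10.0, 20.0, /*count=*/10.0},
                                {12.0, 44.0, 0.0}};
    double count = stats[0][dim];
    std::vector<double> feat = {1.5, 2.5};
    for (int d = 0; d < dim; ++d) {
        double mean = stats[0][d] / count;                 // 1.0, 2.0
        double var = stats[1][d] / count - mean * mean;    // 0.2, 0.4
        feat[d] = (feat[d] - mean) / std::sqrt(var);
    }
    std::printf("%.3f %.3f\n", feat[0], feat[1]);
    return 0;
}
```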
diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h
new file mode 100644
index 00000000000..352d1e1677e
--- /dev/null
+++ b/speechx/speechx/frontend/normalizer.h
@@ -0,0 +1,89 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#pragma once
+
+#include "base/common.h"
+#include "frontend/feature_extractor_interface.h"
+#include "kaldi/matrix/kaldi-matrix.h"
+#include "kaldi/util/options-itf.h"
+
+namespace ppspeech {
+
+struct DecibelNormalizerOptions {
+    float target_db;
+    float max_gain_db;
+    bool convert_int_float;
+    DecibelNormalizerOptions()
+        : target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
+
+    void Register(kaldi::OptionsItf* opts) {
+        opts->Register(
+            "target-db", &target_db, "target db for db normalization");
+        opts->Register(
+            "max-gain-db", &max_gain_db, "max gain db for db normalization");
+        opts->Register("convert-int-float",
+                       &convert_int_float,
+                       "whether to convert int samples to float");
+    }
+};
+
+class DecibelNormalizer : public FeatureExtractorInterface {
+  public:
+    explicit DecibelNormalizer(
+        const DecibelNormalizerOptions& opts,
+        std::unique_ptr<FeatureExtractorInterface> base_extractor);
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
+    // normalized audio has dim 1.
+    virtual size_t Dim() const { return dim_; }
+    virtual void SetFinished() { base_extractor_->SetFinished(); }
+    virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
+    virtual void Reset() { base_extractor_->Reset(); }
+
+  private:
+    bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
+    DecibelNormalizerOptions opts_;
+    size_t dim_;
+    std::unique_ptr<FeatureExtractorInterface> base_extractor_;
+    kaldi::Vector<kaldi::BaseFloat> waveform_;
+};
+
+
+class CMVN : public FeatureExtractorInterface {
+  public:
+    explicit CMVN(std::string cmvn_file,
+                  std::unique_ptr<FeatureExtractorInterface> base_extractor);
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
+
+    // the length of feats = feature_row * feature_dim,
+    // the Matrix is squashed into a Vector
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
+    // dim_ is the feature dim.
+    virtual size_t Dim() const { return dim_; }
+    virtual void SetFinished() { base_extractor_->SetFinished(); }
+    virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
+    virtual void Reset() { base_extractor_->Reset(); }
+
+  private:
+    void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
+    void ApplyCMVN(kaldi::MatrixBase<kaldi::BaseFloat>* feats);
+    kaldi::Matrix<double> stats_;
+    std::unique_ptr<FeatureExtractorInterface> base_extractor_;
+    size_t dim_;
+    bool var_norm_;
+};
+
+}  // namespace ppspeech
\ No newline at end of file
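Putting the pieces together, a hypothetical end-to-end composition of the classes this patch introduces: raw audio, decibel normalization, linear spectrogram, CMVN, feature cache. The CMVN stats path is a placeholder and the whole snippet is an untested sketch of the intended wiring, not code from the patch:

```cpp
#include "frontend/feature_cache.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/raw_audio.h"

int main() {
    using ppspeech::FeatureExtractorInterface;
    std::unique_ptr<FeatureExtractorInterface> source(
        new ppspeech::RawAudioCache());
    ppspeech::DecibelNormalizerOptions db_opts;
    std::unique_ptr<FeatureExtractorInterface> db_norm(
        new ppspeech::DecibelNormalizer(db_opts, std::move(source)));
    ppspeech::LinearSpectrogramOptions spec_opts;
    std::unique_ptr<FeatureExtractorInterface> spectrogram(
        new ppspeech::LinearSpectrogram(spec_opts, std::move(db_norm)));
    std::unique_ptr<FeatureExtractorInterface> cmvn(
        new ppspeech::CMVN("./cmvn.ark", std::move(spectrogram)));  // placeholder path
    ppspeech::FeatureCache cache(100, std::move(cmvn));

    // feed one 0.36 s chunk of dummy audio, then drain the cache
    kaldi::Vector<kaldi::BaseFloat> wav(5760);
    wav.Set(0.1);
    cache.Accept(wav);
    cache.SetFinished();
    kaldi::Vector<kaldi::BaseFloat> feats;
    while (cache.Read(&feats)) {
        KALDI_LOG << "got " << feats.Dim() / cache.Dim() << " frames";
    }
    return 0;
}
```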
diff --git a/speechx/speechx/frontend/raw_audio.cc b/speechx/speechx/frontend/raw_audio.cc
new file mode 100644
index 00000000000..21f64362825
--- /dev/null
+++ b/speechx/speechx/frontend/raw_audio.cc
@@ -0,0 +1,78 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "frontend/raw_audio.h"
+#include "kaldi/base/timer.h"
+
+namespace ppspeech {
+
+using kaldi::BaseFloat;
+using kaldi::VectorBase;
+using kaldi::Vector;
+
+RawAudioCache::RawAudioCache(int buffer_size)
+    : finished_(false), data_length_(0), start_(0), timeout_(1) {
+    ring_buffer_.resize(buffer_size);
+}
+
+void RawAudioCache::Accept(const VectorBase<BaseFloat>& waves) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    while (data_length_ + waves.Dim() > ring_buffer_.size()) {
+        ready_feed_condition_.wait(lock);
+    }
+    for (size_t idx = 0; idx < waves.Dim(); ++idx) {
+        // append after the unread samples, which occupy
+        // [start_, start_ + data_length_) modulo the buffer size
+        int32 buffer_idx =
+            (start_ + data_length_ + idx) % ring_buffer_.size();
+        ring_buffer_[buffer_idx] = waves(idx);
+    }
+    data_length_ += waves.Dim();
+}
+
+bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
+    size_t chunk_size = waves->Dim();
+    kaldi::Timer timer;
+    std::unique_lock<std::mutex> lock(mutex_);
+    while (chunk_size > data_length_) {
+        // when the buffer is empty and no more data will be fed,
+        // waiting on a condition variable would deadlock, so poll
+        // with a timeout instead:
+        // ready_read_condition_.wait(lock);
+        int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
+        if (elapsed > timeout_) {
+            if (finished_ == true) {  // read the last chunk of data
+                break;
+            }
+            if (chunk_size > data_length_) {
+                return false;
+            }
+        }
+        usleep(100);  // sleep 0.1 ms
+    }
+
+    // read the last (possibly short) chunk of data
+    if (chunk_size > data_length_) {
+        chunk_size = data_length_;
+        waves->Resize(chunk_size);
+    }
+
+    for (size_t idx = 0; idx < chunk_size; ++idx) {
+        int buff_idx = (start_ + idx) % ring_buffer_.size();
+        waves->Data()[idx] = ring_buffer_[buff_idx];
+    }
+    data_length_ -= chunk_size;
+    start_ = (start_ + chunk_size) % ring_buffer_.size();
+    ready_feed_condition_.notify_one();
+    return true;
+}
+
+}  // namespace ppspeech
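A standalone sketch, not from the patch, of the circular-buffer bookkeeping `RawAudioCache` uses: unread data occupies `[start_, start_ + data_length_)` modulo the buffer size, writers append after it, and readers consume from `start_`:

```cpp
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> ring(8, 0.0f);
    size_t start = 0, length = 0;
    auto write = [&](const std::vector<float>& chunk) {
        for (size_t i = 0; i < chunk.size(); ++i)
            ring[(start + length + i) % ring.size()] = chunk[i];
        length += chunk.size();
    };
    auto read = [&](size_t n) {
        for (size_t i = 0; i < n; ++i)
            std::printf("%.0f ", ring[(start + i) % ring.size()]);
        start = (start + n) % ring.size();
        length -= n;
        std::printf("\n");
    };
    write({1, 2, 3, 4, 5});
    read(3);           // 1 2 3
    write({6, 7, 8});  // wraps around the end of the buffer
    read(5);           // 4 5 6 7 8
    return 0;
}
```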
diff --git a/speechx/speechx/frontend/raw_audio.h b/speechx/speechx/frontend/raw_audio.h
new file mode 100644
index 00000000000..ce75c137cf3
--- /dev/null
+++ b/speechx/speechx/frontend/raw_audio.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#pragma once
+
+#include "base/common.h"
+#include "frontend/feature_extractor_interface.h"
+
+namespace ppspeech {
+
+class RawAudioCache : public FeatureExtractorInterface {
+  public:
+    explicit RawAudioCache(int buffer_size = kint16max);
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
+    // the audio dim is 1
+    virtual size_t Dim() const { return 1; }
+    virtual void SetFinished() {
+        std::lock_guard<std::mutex> lock(mutex_);
+        finished_ = true;
+    }
+    virtual bool IsFinished() const { return finished_; }
+    virtual void Reset() {
+        start_ = 0;
+        data_length_ = 0;
+        finished_ = false;
+    }
+
+  private:
+    std::vector<kaldi::BaseFloat> ring_buffer_;
+    size_t start_;
+    size_t data_length_;
+    bool finished_;
+    mutable std::mutex mutex_;
+    std::condition_variable ready_feed_condition_;
+    kaldi::int32 timeout_;
+
+    DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
+};
+
+// a data source for testing different frontend modules;
+// it accepts either waveform or features.
+class RawDataCache : public FeatureExtractorInterface {
+  public:
+    explicit RawDataCache() { finished_ = false; }
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
+        data_ = inputs;
+    }
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
+        if (data_.Dim() == 0) {
+            return false;
+        }
+        (*feats) = data_;
+        data_.Resize(0);
+        return true;
+    }
+    virtual size_t Dim() const { return dim_; }
+    virtual void SetFinished() { finished_ = true; }
+    virtual bool IsFinished() const { return finished_; }
+    void SetDim(int32 dim) { dim_ = dim; }
+    virtual void Reset() { finished_ = false; }
+
+  private:
+    kaldi::Vector<kaldi::BaseFloat> data_;
+    bool finished_;
+    int32 dim_;
+
+    DISALLOW_COPY_AND_ASSIGN(RawDataCache);
+};
+
+}  // namespace ppspeech
diff --git a/speechx/speechx/frontend/window.h b/speechx/speechx/frontend/window.h
new file mode 100644
index 00000000000..70d6307ec0c
--- /dev/null
+++ b/speechx/speechx/frontend/window.h
@@ -0,0 +1,15 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// wraps the window extraction of kaldi's feat module (todo)
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc
new file mode 100644
index 00000000000..42d1d2af4c7
--- /dev/null
+++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc
@@ -0,0 +1,1020 @@
+// decoder/lattice-faster-decoder.cc
+
+// Copyright 2009-2012  Microsoft Corporation  Mirko Hannemann
+//           2013-2018  Johns Hopkins University (Author: Daniel Povey)
+//                2014  Guoguo Chen
+//                2018  Zhehuai Chen
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/lattice-faster-decoder.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +// instantiate this class once for each thing you have to decode. +template +LatticeFasterDecoderTpl::LatticeFasterDecoderTpl( + const FST &fst, const LatticeFasterDecoderConfig &config) + : fst_(&fst), + delete_fst_(false), + config_(config), + num_toks_(0), + token_pool_(config.memory_pool_tokens_block_size), + forward_link_pool_(config.memory_pool_links_block_size) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeFasterDecoderTpl::LatticeFasterDecoderTpl( + const LatticeFasterDecoderConfig &config, FST *fst) + : fst_(fst), + delete_fst_(true), + config_(config), + num_toks_(0), + token_pool_(config.memory_pool_tokens_block_size), + forward_link_pool_(config.memory_pool_links_block_size) { + config.Check(); + toks_.SetSize(1000); // just so on the first frame we do something reasonable. +} + +template +LatticeFasterDecoderTpl::~LatticeFasterDecoderTpl() { + DeleteElems(toks_.Clear()); + ClearActiveTokens(); + if (delete_fst_) delete fst_; +} + +template +void LatticeFasterDecoderTpl::InitDecoding() { + // clean up from last time: + DeleteElems(toks_.Clear()); + cost_offsets_.clear(); + ClearActiveTokens(); + warned_ = false; + num_toks_ = 0; + decoding_finalized_ = false; + final_costs_.clear(); + StateId start_state = fst_->Start(); + KALDI_ASSERT(start_state != fst::kNoStateId); + active_toks_.resize(1); + Token *start_tok = + new (token_pool_.Allocate()) Token(0.0, 0.0, NULL, NULL, NULL); + active_toks_[0].toks = start_tok; + toks_.Insert(start_state, start_tok); + num_toks_++; + ProcessNonemitting(config_.beam); +} + +// Returns true if any kind of traceback is available (not necessarily from +// a final state). It should only very rarely return false; this indicates +// an unusual search error. +template +bool LatticeFasterDecoderTpl::Decode(DecodableInterface *decodable) { + InitDecoding(); + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + AdvanceDecoding(decodable); + FinalizeDecoding(); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). + return !active_toks_.empty() && active_toks_.back().toks != NULL; +} + + +// Outputs an FST corresponding to the single best path through the lattice. 
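That single-best-path extraction is usually the tail end of a streaming loop such as the hypothetical driver below; it is not part of this patch, and it assumes the usual `LatticeFasterDecoder` typedef from the header plus Kaldi's `DecodableInterface`:

```cpp
#include "decoder/lattice-faster-decoder.h"

void DecodeStreaming(kaldi::LatticeFasterDecoder* decoder,
                     kaldi::DecodableInterface* decodable) {
    decoder->InitDecoding();
    // AdvanceDecoding() consumes whatever frames the decodable has ready;
    // a real server would interleave this loop with feature feeding.
    while (!decodable->IsLastFrame(decoder->NumFramesDecoded() - 1)) {
        decoder->AdvanceDecoding(decodable);
    }
    decoder->FinalizeDecoding();  // final pruning with final-probs

    kaldi::Lattice best_path;
    if (decoder->GetBestPath(&best_path)) {
        // extract words from best_path, e.g. with GetLinearSymbolSequence()
    }
}
```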
+template +bool LatticeFasterDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + Lattice raw_lat; + GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, olat); + return (olat->NumStates() != 0); +} + + +// Outputs an FST corresponding to the raw, state-level lattice +template +bool LatticeFasterDecoderTpl::GetRawLattice( + Lattice *ofst, + bool use_final_probs) const { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). + if (decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (decoding_finalized_ ? final_costs_ : final_costs_local); + if (!decoding_finalized_ && use_final_probs) + ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + const int32 bucket_count = num_toks_/2 + 3; + unordered_map tok_map(bucket_count); + // First create all states. + std::vector token_list; + for (int32 f = 0; f <= num_frames; f++) { + if (active_toks_[f].toks == NULL) { + KALDI_WARN << "GetRawLattice: no tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + TopSortTokens(active_toks_[f].toks, &token_list); + for (size_t i = 0; i < token_list.size(); i++) + if (token_list[i] != NULL) + tok_map[token_list[i]] = ofst->AddState(); + } + // The next statement sets the start state of the output FST. Because we + // topologically sorted the tokens, state zero must be the start-state. + ofst->SetStart(0); + + KALDI_VLOG(4) << "init:" << num_toks_/2 + 3 << " buckets:" + << tok_map.bucket_count() << " load:" << tok_map.load_factor() + << " max:" << tok_map.max_load_factor(); + // Now create all arcs. + for (int32 f = 0; f <= num_frames; f++) { + for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) { + StateId cur_state = tok_map[tok]; + for (ForwardLinkT *l = tok->links; + l != NULL; + l = l->next) { + typename unordered_map::const_iterator + iter = tok_map.find(l->next_tok); + StateId nextstate = iter->second; + KALDI_ASSERT(iter != tok_map.end()); + BaseFloat cost_offset = 0.0; + if (l->ilabel != 0) { // emitting.. + KALDI_ASSERT(f >= 0 && f < cost_offsets_.size()); + cost_offset = cost_offsets_[f]; + } + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + if (f == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + } + return (ofst->NumStates() > 0); +} + + +// This function is now deprecated, since now we do determinization from outside +// the LatticeFasterDecoder class. Outputs an FST corresponding to the +// lattice-determinized lattice (one path per word sequence). 
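Because determinization now happens outside the decoder, callers typically reproduce the steps that the deprecated `GetLattice()` below performs. A sketch of that external path; the `fst::` helper names come from Kaldi's `lat/` code and should be treated as assumptions rather than part of this patch:

```cpp
#include "decoder/lattice-faster-decoder.h"

bool RawToCompact(const kaldi::LatticeFasterDecoder& decoder,
                  kaldi::BaseFloat lattice_beam,
                  kaldi::CompactLattice* clat) {
    kaldi::Lattice raw;
    if (!decoder.GetRawLattice(&raw)) return false;
    fst::Invert(&raw);  // put word labels on the input side
    fst::ILabelCompare<kaldi::LatticeArc> comp;
    fst::ArcSort(&raw, comp);  // ilabel sort speeds up determinization
    fst::DeterminizeLatticePrunedOptions opts;
    return fst::DeterminizeLatticePruned(raw, lattice_beam, clat, opts);
}
```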
+template +bool LatticeFasterDecoderTpl::GetLattice(CompactLattice *ofst, + bool use_final_probs) const { + Lattice raw_fst; + GetRawLattice(&raw_fst, use_final_probs); + Invert(&raw_fst); // make it so word labels are on the input. + // (in phase where we get backward-costs). + fst::ILabelCompare ilabel_comp; + ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes + // lattice-determinization more efficient. + + fst::DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = config_.det_opts.max_mem; + + DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts); + raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed. + Connect(ofst); // Remove unreachable states... there might be + // a small number of these, in some cases. + // Note: if something went wrong and the raw lattice was empty, + // we should still get to this point in the code without warnings or failures. + return (ofst->NumStates() != 0); +} + +template +void LatticeFasterDecoderTpl::PossiblyResizeHash(size_t num_toks) { + size_t new_sz = static_cast(static_cast(num_toks) + * config_.hash_ratio); + if (new_sz > toks_.Size()) { + toks_.SetSize(new_sz); + } +} + +/* + A note on the definition of extra_cost. + + extra_cost is used in pruning tokens, to save memory. + + extra_cost can be thought of as a beta (backward) cost assuming + we had set the betas on currently-active tokens to all be the negative + of the alphas for those tokens. (So all currently active tokens would + be on (tied) best paths). + + We can use the extra_cost to accurately prune away tokens that we know will + never appear in the lattice. If the extra_cost is greater than the desired + lattice beam, the token would provably never appear in the lattice, so we can + prune away the token. + + (Note: we don't update all the extra_costs every time we update a frame; we + only do it every 'config_.prune_interval' frames). + */ + +// FindOrAddToken either locates a token in hash of toks_, +// or if necessary inserts a new, empty token (i.e. with no forward links) +// for the current frame. [note: it's inserted if necessary into hash toks_ +// and also into the singly linked list of tokens active on this frame +// (whose head is at active_toks_[frame]). +template +inline typename LatticeFasterDecoderTpl::Elem* +LatticeFasterDecoderTpl::FindOrAddToken( + StateId state, int32 frame_plus_one, BaseFloat tot_cost, + Token *backpointer, bool *changed) { + // Returns the Token pointer. Sets "changed" (if non-NULL) to true + // if the token was newly created or the cost changed. + KALDI_ASSERT(frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + Elem *e_found = toks_.Insert(state, NULL); + if (e_found->val == NULL) { // no such token presently. + const BaseFloat extra_cost = 0.0; + // tokens on the currently final frame have zero extra_cost + // as any of them could end up + // on the winning path. + Token *new_tok = new (token_pool_.Allocate()) + Token(tot_cost, extra_cost, NULL, toks, backpointer); + // NULL: no forward links yet + toks = new_tok; + num_toks_++; + e_found->val = new_tok; + if (changed) *changed = true; + return e_found; + } else { + Token *tok = e_found->val; // There is an existing Token for this state. + if (tok->tot_cost > tot_cost) { // replace old token + tok->tot_cost = tot_cost; + // SetBackpointer() just does tok->backpointer = backpointer in + // the case where Token == BackpointerToken, else nothing. 
+ tok->SetBackpointer(backpointer); + // we don't allocate a new token, the old stays linked in active_toks_ + // we only replace the tot_cost + // in the current frame, there are no forward links (and no extra_cost) + // only in ProcessNonemitting we have to delete forward links + // in case we visit a state for the second time + // those forward links, that lead to this replaced token before: + // they remain and will hopefully be pruned later (PruneForwardLinks...) + if (changed) *changed = true; + } else { + if (changed) *changed = false; + } + return e_found; + } +} + +// prunes outgoing links for all tokens in active_toks_[frame] +// it's called by PruneActiveTokens +// all links, that have link_extra_cost > lattice_beam are pruned +template +void LatticeFasterDecoderTpl::PruneForwardLinks( + int32 frame_plus_one, bool *extra_costs_changed, + bool *links_pruned, BaseFloat delta) { + // delta is the amount by which the extra_costs must change + // If delta is larger, we'll tend to go back less far + // toward the beginning of the file. + // extra_costs_changed is set to true if extra_cost was changed for any token + // links_pruned is set to true if any link in any token was pruned + + *extra_costs_changed = false; + *links_pruned = false; + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + if (active_toks_[frame_plus_one].toks == NULL) { // empty list; should not happen. + if (!warned_) { + KALDI_WARN << "No tokens alive [doing pruning].. warning first " + "time only for each utterance\n"; + warned_ = true; + } + } + + // We have to iterate until there is no more change, because the links + // are not guaranteed to be in topological order. + bool changed = true; // difference new minus old extra cost >= delta ? + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost for tok. + BaseFloat tok_extra_cost = std::numeric_limits::infinity(); + // tok_extra_cost is the best (min) of link_extra_cost of outgoing links + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... + Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + // link_exta_cost is the difference in score between the best paths + // through link source state and through link destination state + KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + forward_link_pool_.Free(link); + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. 
+ if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; // move to next link + link = link->next; + } + } // for all outgoing links + if (fabs(tok_extra_cost - tok->extra_cost) > delta) + changed = true; // difference new minus old is bigger than delta + tok->extra_cost = tok_extra_cost; + // will be +infinity or <= lattice_beam_. + // infinity indicates, that no forward link survived pruning + } // for all Token on active_toks_[frame] + if (changed) *extra_costs_changed = true; + + // Note: it's theoretically possible that aggressive compiler + // optimizations could cause an infinite loop here for small delta and + // high-dynamic-range scores. + } // while changed +} + +// PruneForwardLinksFinal is a version of PruneForwardLinks that we call +// on the final frame. If there are final tokens active, it uses +// the final-probs for pruning, otherwise it treats all tokens as final. +template +void LatticeFasterDecoderTpl::PruneForwardLinksFinal() { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame_plus_one = active_toks_.size() - 1; + + if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen. + KALDI_WARN << "No tokens alive at end of file"; + + typedef typename unordered_map::const_iterator IterType; + ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_); + decoding_finalized_ = true; + // We call DeleteElems() as a nicety, not because it's really necessary; + // otherwise there would be a time, after calling PruneTokensForFrame() on the + // final frame, when toks_.GetList() or toks_.Clear() would contain pointers + // to nonexistent tokens. + DeleteElems(toks_.Clear()); + + // Now go through tokens on this frame, pruning forward links... may have to + // iterate a few times until there is no more change, because the list is not + // in topological order. This is a modified version of the code in + // PruneForwardLinks, but here we also take account of the final-probs. + bool changed = true; + BaseFloat delta = 1.0e-05; + while (changed) { + changed = false; + for (Token *tok = active_toks_[frame_plus_one].toks; + tok != NULL; tok = tok->next) { + ForwardLinkT *link, *prev_link = NULL; + // will recompute tok_extra_cost. It has a term in it that corresponds + // to the "final-prob", so instead of initializing tok_extra_cost to infinity + // below we set it to the difference between the (score+final_prob) of this token, + // and the best such (score+final_prob). + BaseFloat final_cost; + if (final_costs_.empty()) { + final_cost = 0.0; + } else { + IterType iter = final_costs_.find(tok); + if (iter != final_costs_.end()) + final_cost = iter->second; + else + final_cost = std::numeric_limits::infinity(); + } + BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_; + // tok_extra_cost will be a "min" over either directly being final, or + // being indirectly final through other links, and the loop below may + // decrease its value: + for (link = tok->links; link != NULL; ) { + // See if we need to excise this link... 
+ Token *next_tok = link->next_tok; + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); + if (link_extra_cost > config_.lattice_beam) { // excise link + ForwardLinkT *next_link = link->next; + if (prev_link != NULL) prev_link->next = next_link; + else tok->links = next_link; + forward_link_pool_.Free(link); + link = next_link; // advance link but leave prev_link the same. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. + if (link_extra_cost < -0.01) + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + link_extra_cost = 0.0; + } + if (link_extra_cost < tok_extra_cost) + tok_extra_cost = link_extra_cost; + prev_link = link; + link = link->next; + } + } + // prune away tokens worse than lattice_beam above best path. This step + // was not necessary in the non-final case because then, this case + // showed up as having no forward links. Here, the tok_extra_cost has + // an extra component relating to the final-prob. + if (tok_extra_cost > config_.lattice_beam) + tok_extra_cost = std::numeric_limits::infinity(); + // to be pruned in PruneTokensForFrame + + if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) + changed = true; + tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. + } + } // while changed +} + +template +BaseFloat LatticeFasterDecoderTpl::FinalRelativeCost() const { + if (!decoding_finalized_) { + BaseFloat relative_cost; + ComputeFinalCosts(NULL, &relative_cost, NULL); + return relative_cost; + } else { + // we're not allowed to call that function if FinalizeDecoding() has + // been called; return a cached value. + return final_relative_cost_; + } +} + + +// Prune away any tokens on this frame that have no forward links. +// [we don't do this in PruneForwardLinks because it would give us +// a problem with dangling pointers]. +// It's called by PruneActiveTokens if any forward links have been pruned +template +void LatticeFasterDecoderTpl::PruneTokensForFrame(int32 frame_plus_one) { + KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size()); + Token *&toks = active_toks_[frame_plus_one].toks; + if (toks == NULL) + KALDI_WARN << "No tokens alive [doing pruning]"; + Token *tok, *next_tok, *prev_tok = NULL; + for (tok = toks; tok != NULL; tok = next_tok) { + next_tok = tok->next; + if (tok->extra_cost == std::numeric_limits::infinity()) { + // token is unreachable from end of graph; (no forward links survived) + // excise tok from list and delete tok. + if (prev_tok != NULL) prev_tok->next = tok->next; + else toks = tok->next; + token_pool_.Free(tok); + num_toks_--; + } else { // fetch next Token + prev_tok = tok; + } + } +} + +// Go backwards through still-alive tokens, pruning them, starting not from +// the current frame (where we want to keep all tokens) but from the frame before +// that. We go backwards through the frames and stop when we reach a point +// where the delta-costs are not changing (and the delta controls when we consider +// a cost to have "not changed"). +template +void LatticeFasterDecoderTpl::PruneActiveTokens(BaseFloat delta) { + int32 cur_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // The index "f" below represents a "frame plus one", i.e. you'd have to subtract + // one to get the corresponding index for the decodable object. 
+ for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) { + // Reason why we need to prune forward links in this situation: + // (1) we have never pruned them (new TokenList) + // (2) we have not yet pruned the forward links to the next f, + // after any of those tokens have changed their extra_cost. + if (active_toks_[f].must_prune_forward_links) { + bool extra_costs_changed = false, links_pruned = false; + PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta); + if (extra_costs_changed && f > 0) // any token has changed extra_cost + active_toks_[f-1].must_prune_forward_links = true; + if (links_pruned) // any link was pruned + active_toks_[f].must_prune_tokens = true; + active_toks_[f].must_prune_forward_links = false; // job done + } + if (f+1 < cur_frame_plus_one && // except for last f (no forward links) + active_toks_[f+1].must_prune_tokens) { + PruneTokensForFrame(f+1); + active_toks_[f+1].must_prune_tokens = false; + } + } + KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +template +void LatticeFasterDecoderTpl::ComputeFinalCosts( + unordered_map *final_costs, + BaseFloat *final_relative_cost, + BaseFloat *final_best_cost) const { + KALDI_ASSERT(!decoding_finalized_); + if (final_costs != NULL) + final_costs->clear(); + const Elem *final_toks = toks_.GetList(); + BaseFloat infinity = std::numeric_limits::infinity(); + BaseFloat best_cost = infinity, + best_cost_with_final = infinity; + + while (final_toks != NULL) { + StateId state = final_toks->key; + Token *tok = final_toks->val; + const Elem *next = final_toks->tail; + BaseFloat final_cost = fst_->Final(state).Value(); + BaseFloat cost = tok->tot_cost, + cost_with_final = cost + final_cost; + best_cost = std::min(cost, best_cost); + best_cost_with_final = std::min(cost_with_final, best_cost_with_final); + if (final_costs != NULL && final_cost != infinity) + (*final_costs)[tok] = final_cost; + final_toks = next; + } + if (final_relative_cost != NULL) { + if (best_cost == infinity && best_cost_with_final == infinity) { + // Likely this will only happen if there are no tokens surviving. + // This seems the least bad way to handle it. + *final_relative_cost = infinity; + } else { + *final_relative_cost = best_cost_with_final - best_cost; + } + } + if (final_best_cost != NULL) { + if (best_cost_with_final != infinity) { // final-state exists. + *final_best_cost = best_cost_with_final; + } else { // no final-state exists. + *final_best_cost = best_cost; + } + } +} + +template +void LatticeFasterDecoderTpl::AdvanceDecoding(DecodableInterface *decodable, + int32 max_num_frames) { + if (std::is_same >::value) { + // if the type 'FST' is the FST base-class, then see if the FST type of fst_ + // is actually VectorFst or ConstFst. If so, call the AdvanceDecoding() + // function after casting *this to the more specific type. 
+ if (fst_->Type() == "const") { + LatticeFasterDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } else if (fst_->Type() == "vector") { + LatticeFasterDecoderTpl, Token> *this_cast = + reinterpret_cast, Token>* >(this); + this_cast->AdvanceDecoding(decodable, max_num_frames); + return; + } + } + + + KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ && + "You must call InitDecoding() before AdvanceDecoding"); + int32 num_frames_ready = decodable->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). + KALDI_ASSERT(num_frames_ready >= NumFramesDecoded()); + int32 target_frames_decoded = num_frames_ready; + if (max_num_frames >= 0) + target_frames_decoded = std::min(target_frames_decoded, + NumFramesDecoded() + max_num_frames); + while (NumFramesDecoded() < target_frames_decoded) { + if (NumFramesDecoded() % config_.prune_interval == 0) { + PruneActiveTokens(config_.lattice_beam * config_.prune_scale); + } + BaseFloat cost_cutoff = ProcessEmitting(decodable); + ProcessNonemitting(cost_cutoff); + } +} + +// FinalizeDecoding() is a version of PruneActiveTokens that we call +// (optionally) on the final frame. Takes into account the final-prob of +// tokens. This function used to be called PruneActiveTokensFinal(). +template +void LatticeFasterDecoderTpl::FinalizeDecoding() { + int32 final_frame_plus_one = NumFramesDecoded(); + int32 num_toks_begin = num_toks_; + // PruneForwardLinksFinal() prunes final frame (with final-probs), and + // sets decoding_finalized_. + PruneForwardLinksFinal(); + for (int32 f = final_frame_plus_one - 1; f >= 0; f--) { + bool b1, b2; // values not used. + BaseFloat dontcare = 0.0; // delta of zero means we must always update + PruneForwardLinks(f, &b1, &b2, dontcare); + PruneTokensForFrame(f + 1); + } + PruneTokensForFrame(0); + KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin + << " to " << num_toks_; +} + +/// Gets the weight cutoff. Also counts the active tokens. +template +BaseFloat LatticeFasterDecoderTpl::GetCutoff(Elem *list_head, size_t *tok_count, + BaseFloat *adaptive_beam, Elem **best_elem) { + BaseFloat best_weight = std::numeric_limits::infinity(); + // positive == high cost == bad. 
+ size_t count = 0; + if (config_.max_active == std::numeric_limits::max() && + config_.min_active == 0) { + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = static_cast(e->val->tot_cost); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + if (adaptive_beam != NULL) *adaptive_beam = config_.beam; + return best_weight + config_.beam; + } else { + tmp_array_.clear(); + for (Elem *e = list_head; e != NULL; e = e->tail, count++) { + BaseFloat w = e->val->tot_cost; + tmp_array_.push_back(w); + if (w < best_weight) { + best_weight = w; + if (best_elem) *best_elem = e; + } + } + if (tok_count != NULL) *tok_count = count; + + BaseFloat beam_cutoff = best_weight + config_.beam, + min_active_cutoff = std::numeric_limits::infinity(), + max_active_cutoff = std::numeric_limits::infinity(); + + KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded() + << " is " << tmp_array_.size(); + + if (tmp_array_.size() > static_cast(config_.max_active)) { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.max_active, + tmp_array_.end()); + max_active_cutoff = tmp_array_[config_.max_active]; + } + if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam. + if (adaptive_beam) + *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta; + return max_active_cutoff; + } + if (tmp_array_.size() > static_cast(config_.min_active)) { + if (config_.min_active == 0) min_active_cutoff = best_weight; + else { + std::nth_element(tmp_array_.begin(), + tmp_array_.begin() + config_.min_active, + tmp_array_.size() > static_cast(config_.max_active) ? + tmp_array_.begin() + config_.max_active : + tmp_array_.end()); + min_active_cutoff = tmp_array_[config_.min_active]; + } + } + if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam. + if (adaptive_beam) + *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta; + return min_active_cutoff; + } else { + *adaptive_beam = config_.beam; + return beam_cutoff; + } + } +} + +template +BaseFloat LatticeFasterDecoderTpl::ProcessEmitting( + DecodableInterface *decodable) { + KALDI_ASSERT(active_toks_.size() > 0); + int32 frame = active_toks_.size() - 1; // frame is the frame-index + // (zero-based) used to get likelihoods + // from the decodable object. + active_toks_.resize(active_toks_.size() + 1); + + Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_ + // in simple-decoder.h. Removes the Elems from + // being indexed in the hash in toks_. + Elem *best_elem = NULL; + BaseFloat adaptive_beam; + size_t tok_cnt; + BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem); + KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is " + << adaptive_beam; + + PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. + + BaseFloat next_cutoff = std::numeric_limits::infinity(); + // pruning "online" before having seen all tokens + + BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good + // dynamic range. + + + // First process the best token to get a hopefully + // reasonably tight bound on the next cutoff. The only + // products of the next block are "next_cutoff" and "cost_offset". 
+ if (best_elem) { + StateId state = best_elem->key; + Token *tok = best_elem->val; + cost_offset = - tok->tot_cost; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat new_weight = arc.weight.Value() + cost_offset - + decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; + if (new_weight + adaptive_beam < next_cutoff) + next_cutoff = new_weight + adaptive_beam; + } + } + } + + // Store the offset on the acoustic likelihoods that we're applying. + // Could just do cost_offsets_.push_back(cost_offset), but we + // do it this way as it's more robust to future code changes. + cost_offsets_.resize(frame + 1, 0.0); + cost_offsets_[frame] = cost_offset; + + // the tokens are now owned here, in final_toks, and the hash is empty. + // 'owned' is a complex thing here; the point is we need to call DeleteElem + // on each elem 'e' to let toks_ know we're done with them. + for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { + // loop this way because we delete "e" as we go. + StateId state = e->key; + Token *tok = e->val; + if (tok->tot_cost <= cur_cutoff) { + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { // propagate.. + BaseFloat ac_cost = cost_offset - + decodable->LogLikelihood(frame, arc.ilabel), + graph_cost = arc.weight.Value(), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + if (tot_cost >= next_cutoff) continue; + else if (tot_cost + adaptive_beam < next_cutoff) + next_cutoff = tot_cost + adaptive_beam; // prune by best current token + // Note: the frame indexes into active_toks_ are one-based, + // hence the + 1. + Elem *e_next = FindOrAddToken(arc.nextstate, + frame + 1, tot_cost, tok, NULL); + // NULL: no change indicator needed + + // Add ForwardLink from tok to next_tok (put on head of list tok->links) + tok->links = new (forward_link_pool_.Allocate()) + ForwardLinkT(e_next->val, arc.ilabel, arc.olabel, graph_cost, + ac_cost, tok->links); + } + } // for all arcs + } + e_tail = e->tail; + toks_.Delete(e); // delete Elem + } + return next_cutoff; +} + +// static inline +template +void LatticeFasterDecoderTpl::DeleteForwardLinks(Token *tok) { + ForwardLinkT *l = tok->links, *m; + while (l != NULL) { + m = l->next; + forward_link_pool_.Free(l); + l = m; + } + tok->links = NULL; +} + + +template +void LatticeFasterDecoderTpl::ProcessNonemitting(BaseFloat cutoff) { + KALDI_ASSERT(!active_toks_.empty()); + int32 frame = static_cast(active_toks_.size()) - 2; + // Note: "frame" is the time-index we just processed, or -1 if + // we are processing the nonemitting transitions before the + // first frame (called from InitDecoding()). + + // Processes nonemitting arcs for one frame. Propagates within toks_. + // Note-- this queue structure is not very optimal as + // it may cause us to process states unnecessarily (e.g. more than once), + // but in the baseline code, turning this vector into a set to fix this + // problem did not improve overall speed. 
+ + KALDI_ASSERT(queue_.empty()); + + if (toks_.GetList() == NULL) { + if (!warned_) { + KALDI_WARN << "Error, no surviving tokens: frame is " << frame; + warned_ = true; + } + } + + for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { + StateId state = e->key; + if (fst_->NumInputEpsilons(state) != 0) + queue_.push_back(e); + } + + while (!queue_.empty()) { + const Elem *e = queue_.back(); + queue_.pop_back(); + + StateId state = e->key; + Token *tok = e->val; // would segfault if e is a NULL pointer but this can't happen. + BaseFloat cur_cost = tok->tot_cost; + if (cur_cost >= cutoff) // Don't bother processing successors. + continue; + // If "tok" has any existing forward links, delete them, + // because we're about to regenerate them. This is a kind + // of non-optimality (remember, this is the simple decoder), + // but since most states are emitting it's not a huge issue. + DeleteForwardLinks(tok); // necessary when re-visiting + tok->links = NULL; + for (fst::ArcIterator aiter(*fst_, state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel == 0) { // propagate nonemitting only... + BaseFloat graph_cost = arc.weight.Value(), + tot_cost = cur_cost + graph_cost; + if (tot_cost < cutoff) { + bool changed; + + Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, + tok, &changed); + + tok->links = new (forward_link_pool_.Allocate()) ForwardLinkT( + e_new->val, 0, arc.olabel, graph_cost, 0, tok->links); + + // "changed" tells us whether the new token has a different + // cost from before, or is new [if so, add into queue]. + if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0) + queue_.push_back(e_new); + } + } + } // for all arcs + } // while queue not empty +} + + +template +void LatticeFasterDecoderTpl::DeleteElems(Elem *list) { + for (Elem *e = list, *e_tail; e != NULL; e = e_tail) { + e_tail = e->tail; + toks_.Delete(e); + } +} + +template +void LatticeFasterDecoderTpl::ClearActiveTokens() { // a cleanup routine, at utt end/begin + for (size_t i = 0; i < active_toks_.size(); i++) { + // Delete all tokens alive on this frame, and any forward + // links they may have. + for (Token *tok = active_toks_[i].toks; tok != NULL; ) { + DeleteForwardLinks(tok); + Token *next_tok = tok->next; + token_pool_.Free(tok); + num_toks_--; + tok = next_tok; + } + } + active_toks_.clear(); + KALDI_ASSERT(num_toks_ == 0); +} + +// static +template +void LatticeFasterDecoderTpl::TopSortTokens( + Token *tok_list, std::vector *topsorted_list) { + unordered_map token2pos; + typedef typename unordered_map::iterator IterType; + int32 num_toks = 0; + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + num_toks++; + int32 cur_pos = 0; + // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0. + // This is likely to be in closer to topological order than + // if we had given them ascending order, because of the way + // new tokens are put at the front of the list. + for (Token *tok = tok_list; tok != NULL; tok = tok->next) + token2pos[tok] = num_toks - ++cur_pos; + + unordered_set reprocess; + + for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) { + Token *tok = iter->first; + int32 pos = iter->second; + for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel == 0) { + // We only need to consider epsilon links, since non-epsilon links + // transition between frames and this function only needs to sort a list + // of tokens from a single frame. 
+        IterType following_iter = token2pos.find(link->next_tok);
+        if (following_iter != token2pos.end()) { // another token on this frame,
+                                                 // so must consider it.
+          int32 next_pos = following_iter->second;
+          if (next_pos < pos) { // reassign the position of the next Token.
+            following_iter->second = cur_pos++;
+            reprocess.insert(link->next_tok);
+          }
+        }
+      }
+    }
+    // In case we had previously assigned this token to be reprocessed, we can
+    // erase it from that set because it's "happy now" (we just processed it).
+    reprocess.erase(tok);
+  }
+
+  size_t max_loop = 1000000, loop_count; // max_loop is to detect epsilon cycles.
+  for (loop_count = 0;
+       !reprocess.empty() && loop_count < max_loop; ++loop_count) {
+    std::vector<Token*> reprocess_vec;
+    for (typename unordered_set<Token*>::iterator iter = reprocess.begin();
+         iter != reprocess.end(); ++iter)
+      reprocess_vec.push_back(*iter);
+    reprocess.clear();
+    for (typename std::vector<Token*>::iterator iter = reprocess_vec.begin();
+         iter != reprocess_vec.end(); ++iter) {
+      Token *tok = *iter;
+      int32 pos = token2pos[tok];
+      // Repeat the processing we did above (for comments, see above).
+      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
+        if (link->ilabel == 0) {
+          IterType following_iter = token2pos.find(link->next_tok);
+          if (following_iter != token2pos.end()) {
+            int32 next_pos = following_iter->second;
+            if (next_pos < pos) {
+              following_iter->second = cur_pos++;
+              reprocess.insert(link->next_tok);
+            }
+          }
+        }
+      }
+    }
+  }
+  KALDI_ASSERT(loop_count < max_loop && "Epsilon loops exist in your decoding "
+               "graph (this is not allowed!)");
+
+  topsorted_list->clear();
+  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
+  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
+    (*topsorted_list)[iter->second] = iter->first;
+}
+
+// Instantiate the template for the combination of token types and FST types
+// that we'll need.
+template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken>;
+template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
+template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;
+
+template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::StdToken>;
+template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::StdToken>;
+
+template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>;
+template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
+template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
+template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::BackpointerToken>;
+template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::BackpointerToken>;
+
+
+} // end namespace kaldi.
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h
new file mode 100644
index 00000000000..2016ad57115
--- /dev/null
+++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h
@@ -0,0 +1,549 @@
+// decoder/lattice-faster-decoder.h
+
+// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
+//           2013-2014 Johns Hopkins University (Author: Daniel Povey)
+//                2014 Guoguo Chen
+//                2018 Zhehuai Chen
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
+#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
+
+#include "decoder/grammar-fst.h"
+#include "fst/fstlib.h"
+#include "fst/memory.h"
+#include "fstext/fstext-lib.h"
+#include "itf/decodable-itf.h"
+#include "lat/determinize-lattice-pruned.h"
+#include "lat/kaldi-lattice.h"
+#include "util/hash-list.h"
+#include "util/stl-utils.h"
+
+namespace kaldi {
+
+struct LatticeFasterDecoderConfig {
+  BaseFloat beam;
+  int32 max_active;
+  int32 min_active;
+  BaseFloat lattice_beam;
+  int32 prune_interval;
+  bool determinize_lattice; // not inspected by this class... used in
+                            // command-line program.
+  BaseFloat beam_delta;
+  BaseFloat hash_ratio;
+  // Note: we don't make prune_scale configurable on the command line, it's not
+  // a very important parameter.  It affects the algorithm that prunes the
+  // tokens as we go.
+  BaseFloat prune_scale;
+
+  // Number of elements in the block for Token and ForwardLink memory
+  // pool allocation.
+  int32 memory_pool_tokens_block_size;
+  int32 memory_pool_links_block_size;
+
+  // Most of the options inside det_opts are not actually queried by the
+  // LatticeFasterDecoder class itself, but by the code that calls it, for
+  // example in the function DecodeUtteranceLatticeFaster.
+  fst::DeterminizeLatticePhonePrunedOptions det_opts;
+
+  LatticeFasterDecoderConfig()
+      : beam(16.0),
+        max_active(std::numeric_limits<int32>::max()),
+        min_active(200),
+        lattice_beam(10.0),
+        prune_interval(25),
+        determinize_lattice(true),
+        beam_delta(0.5),
+        hash_ratio(2.0),
+        prune_scale(0.1),
+        memory_pool_tokens_block_size(1 << 8),
+        memory_pool_links_block_size(1 << 8) {}
+  void Register(OptionsItf *opts) {
+    det_opts.Register(opts);
+    opts->Register("beam", &beam, "Decoding beam.  Larger->slower, more accurate.");
+    opts->Register("max-active", &max_active, "Decoder max active states.  Larger->slower; "
+                   "more accurate");
+    opts->Register("min-active", &min_active, "Decoder minimum #active states.");
+    opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam.  Larger->slower, "
+                   "and deeper lattices");
+    opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
+                   "which to prune tokens");
+    opts->Register("determinize-lattice", &determinize_lattice, "If true, "
+                   "determinize the lattice (lattice-determinization, keeping only "
+                   "best pdf-sequence for each word-sequence).");
+    opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
+                   "parameter is obscure and relates to a speedup in the way the "
+                   "max-active constraint is applied.  Larger is more accurate.");
+    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
+                   "control hash behavior");
+    opts->Register("memory-pool-tokens-block-size", &memory_pool_tokens_block_size,
+                   "Memory pool block size suggestion for storing tokens (in elements). "
+                   "Smaller uses less memory but increases cache misses.");
+    opts->Register("memory-pool-links-block-size", &memory_pool_links_block_size,
+                   "Memory pool block size suggestion for storing links (in elements). "
+                   "Smaller uses less memory but increases cache misses.");
+  }
+  void Check() const {
+    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
+                 && min_active <= max_active
+                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
+                 && prune_scale > 0.0 && prune_scale < 1.0);
+  }
+};
+
+namespace decoder {
+// We will template the decoder on the token type as well as the FST type; this
+// is a mechanism so that we can use the same underlying decoder code for
+// versions of the decoder that support quickly getting the best path
+// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
+// those that do not (LatticeFasterDecoder).
+
+
+// ForwardLinks are the links from a token to a token on the next frame,
+// or sometimes on the current frame (for input-epsilon links).
+template <typename Token>
+struct ForwardLink {
+  using Label = fst::StdArc::Label;
+
+  Token *next_tok;  // the next token [or NULL if represents final-state]
+  Label ilabel;  // ilabel on arc
+  Label olabel;  // olabel on arc
+  BaseFloat graph_cost;  // graph cost of traversing arc (contains LM, etc.)
+  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
+  ForwardLink *next;  // next in singly-linked list of forward arcs (arcs
+                      // in the state-level lattice) from a token.
+  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
+                     BaseFloat graph_cost, BaseFloat acoustic_cost,
+                     ForwardLink *next):
+      next_tok(next_tok), ilabel(ilabel), olabel(olabel),
+      graph_cost(graph_cost), acoustic_cost(acoustic_cost),
+      next(next) { }
+};
+
+
+struct StdToken {
+  using ForwardLinkT = ForwardLink<StdToken>;
+  using Token = StdToken;
+
+  // Standard token type for LatticeFasterDecoder.  Each active HCLG
+  // (decoding-graph) state on each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this link is a
+  // part of, and the cost of the absolute best path, under the assumption that
+  // any of the currently active states at the decoding front may eventually
+  // succeed (e.g. if you were to take the currently active states one by one
+  // and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // 'links' is the head of the singly-linked list of ForwardLinks, which is
+  // what we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  Token *next;
+
+  // This function does nothing and should be optimized out; it's needed
+  // so we can share the regular LatticeFasterDecoderTpl code and the code
+  // for LatticeFasterOnlineDecoder that supports fast traceback.
+  inline void SetBackpointer (Token *backpointer) { }
+
+  // This constructor just ignores the 'backpointer' argument.  That argument is
+  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
+  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
+  // fast way to obtain the best path).
+  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+                  Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
+};
+
+struct BackpointerToken {
+  using ForwardLinkT = ForwardLink<BackpointerToken>;
+  using Token = BackpointerToken;
+
+  // BackpointerToken is like StdToken but also stores a backpointer (see the
+  // 'backpointer' member below).  Each active HCLG (decoding-graph) state on
+  // each frame has one token.
+
+  // tot_cost is the total (LM + acoustic) cost from the beginning of the
+  // utterance up to this point.  (but see cost_offset_, which is subtracted
+  // to keep it in a good numerical range).
+  BaseFloat tot_cost;
+
+  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
+  // minimum difference between the cost of the best path that this token is
+  // part of, and the cost of the absolute best path, under the assumption
+  // that any of the currently active states at the decoding front may
+  // eventually succeed (e.g. if you were to take the currently active states
+  // one by one and compute this difference, and then take the minimum).
+  BaseFloat extra_cost;
+
+  // 'links' is the head of the singly-linked list of ForwardLinks, which is
+  // what we use for lattice generation.
+  ForwardLinkT *links;
+
+  // 'next' is the next in the singly-linked list of tokens for this frame.
+  BackpointerToken *next;
+
+  // Best preceding BackpointerToken (could be one on this frame, connected to
+  // this via an epsilon transition, or one on a previous frame).  This is only
+  // required for an efficient GetBestPath function in
+  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
+  // (the "links" list is what stores the forward links, for that).
+  Token *backpointer;
+
+  inline void SetBackpointer (Token *backpointer) {
+    this->backpointer = backpointer;
+  }
+
+  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
+                          Token *next, Token *backpointer):
+      tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
+      backpointer(backpointer) { }
+};
+
+}  // namespace decoder
+
+
+/** This is the "normal" lattice-generating decoder.
+    See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
+    for more information.
+
+    The decoder is templated on the FST type and the token type.  The token
+    type will normally be StdToken, but also may be BackpointerToken, which is
+    to support quick lookup of the current best path (see
+    lattice-faster-online-decoder.h).
+
+    The FST you invoke this decoder with is expected to be of type
+    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst.  If you invoke it with
+    FST == StdFst and it notices that the actual FST type is
+    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
+    object will internally cast itself to one that is templated on those more
+    specific types; this is an optimization for speed.
+ */
+template <typename FST, typename Token = decoder::StdToken>
+class LatticeFasterDecoderTpl {
+ public:
+  using Arc = typename FST::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+  using ForwardLinkT = decoder::ForwardLink<Token>;
+
+  // Instantiate this class once for each thing you have to decode.
+  // This version of the constructor does not take ownership of
+  // 'fst'.
+  LatticeFasterDecoderTpl(const FST &fst,
+                          const LatticeFasterDecoderConfig &config);
+
+  // This version of the constructor takes ownership of the fst, and will delete
+  // it when this object is destroyed.
+  LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config,
+                          FST *fst);
+
+  void SetOptions(const LatticeFasterDecoderConfig &config) {
+    config_ = config;
+  }
+
+  const LatticeFasterDecoderConfig &GetOptions() const {
+    return config_;
+  }
+
+  ~LatticeFasterDecoderTpl();
+
+  /// Decodes until there are no more frames left in the "decodable" object.
+  /// Note: this may block waiting for input if the "decodable" object blocks.
+  /// Returns true if any kind of traceback is available (not necessarily from a
+  /// final state).
+  bool Decode(DecodableInterface *decodable);
+
+
+  /// Says whether a final-state was active on the last frame.  If it was not, the
+  /// lattice (or traceback) will end with states that are not final-states.
+  bool ReachedFinal() const {
+    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
+  }
+
+  /// Outputs an FST corresponding to the single best path through the lattice.
+  /// Returns true if result is nonempty (using the return status is deprecated,
+  /// it will become void).  If "use_final_probs" is true AND we reached the
+  /// final-state of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.  Note: this just calls GetRawLattice()
+  /// and figures out the shortest path.
+  bool GetBestPath(Lattice *ofst,
+                   bool use_final_probs = true) const;
+
+  /// Outputs an FST corresponding to the raw, state-level
+  /// tracebacks.  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state
+  /// of the graph then it will include those as final-probs, else
+  /// it will treat all final-probs as one.
+  /// The raw lattice will be topologically sorted.
+  ///
+  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
+  /// which also supports a pruning beam, in case for some reason
+  /// you want it pruned tighter than the regular lattice beam.
+  /// We could add that here in future if needed.
+  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;
+
+
+
+  /// [Deprecated, users should now use GetRawLattice and determinize it
+  /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
+  /// Outputs an FST corresponding to the lattice-determinized
+  /// lattice (one path per word sequence).  Returns true if result is nonempty.
+  /// If "use_final_probs" is true AND we reached the final-state of the graph
+  /// then it will include those as final-probs, else it will treat all
+  /// final-probs as one.
+  bool GetLattice(CompactLattice *ofst,
+                  bool use_final_probs = true) const;
+
+  /// InitDecoding initializes the decoding, and should only be used if you
+  /// intend to call AdvanceDecoding().  If you call Decode(), you don't need to
+  /// call this.  You can also call InitDecoding if you have already decoded an
+  /// utterance and want to start with a new utterance.
+  void InitDecoding();
+
+  /// This will decode until there are no more frames ready in the decodable
+  /// object.  You can keep calling it each time more frames become available.
+  /// If max_num_frames is specified, it specifies the maximum number of frames
+  /// the function will decode before returning.
+  void AdvanceDecoding(DecodableInterface *decodable,
+                       int32 max_num_frames = -1);
+
+  /// This function may be optionally called after AdvanceDecoding(), when you
+  /// do not plan to decode any further.  It does an extra pruning step that
+  /// will help to prune the lattices output by GetLattice and (particularly)
+  /// GetRawLattice more completely, particularly toward the end of the
+  /// utterance.  If you call this, you cannot call AdvanceDecoding again (it
+  /// will fail), and you cannot call GetLattice() and related functions with
+  /// use_final_probs = false.  Used to be called PruneActiveTokensFinal().
+  void FinalizeDecoding();
+
+  /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
+  /// more information.  It returns the difference between the best (final-cost
+  /// plus cost) of any token on the final frame, and the best cost of any token
+  /// on the final frame.  If it is infinity it means no final-states were
+  /// present on the final frame.  It will usually be nonnegative.  If it is not
+  /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
+  /// take it as a good indication that we reached the final-state with
+  /// reasonable likelihood.
+  BaseFloat FinalRelativeCost() const;
+
+
+  // Returns the number of frames decoded so far.  The value returned changes
+  // whenever we call ProcessEmitting().
+  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
+
+ protected:
+  // we make things protected instead of private, as code in
+  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
+  // internals.
+
+  // Deletes the elements of the singly linked list tok->links.
+  void DeleteForwardLinks(Token *tok);
+
+  // head of per-frame list of Tokens (list is in topological order),
+  // and something saying whether we ever pruned it using PruneForwardLinks.
+  struct TokenList {
+    Token *toks;
+    bool must_prune_forward_links;
+    bool must_prune_tokens;
+    TokenList(): toks(NULL), must_prune_forward_links(true),
+                 must_prune_tokens(true) { }
+  };
+
+  using Elem = typename HashList<StateId, Token*>::Elem;
+  // Equivalent to:
+  //  struct Elem {
+  //    StateId key;
+  //    Token *val;
+  //    Elem *tail;
+  //  };
+
+  void PossiblyResizeHash(size_t num_toks);
+
+  // FindOrAddToken either locates a token in hash of toks_, or if necessary
+  // inserts a new, empty token (i.e. with no forward links) for the current
+  // frame.  [note: it's inserted if necessary into hash toks_ and also into the
+  // singly linked list of tokens active on this frame (whose head is at
+  // active_toks_[frame]).  The frame_plus_one argument is the acoustic frame
+  // index plus one, which is used to index into the active_toks_ array.
+  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true if the
+  // token was newly created or the cost changed.
+  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
+  // hopefully be optimized out).
+  inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
+                              BaseFloat tot_cost, Token *backpointer,
+                              bool *changed);
+
+  // prunes outgoing links for all tokens in active_toks_[frame];
+  // it's called by PruneActiveTokens.
+  // all links that have link_extra_cost > lattice_beam are pruned.
+  // delta is the amount by which the extra_costs must change
+  // before we set *extra_costs_changed = true.
+  // If delta is larger, we'll tend to go back less far
+  // toward the beginning of the file.
+  // extra_costs_changed is set to true if extra_cost was changed for any token
+  // links_pruned is set to true if any link in any token was pruned
+  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
+                         bool *links_pruned,
+                         BaseFloat delta);
+
+  // This function computes the final-costs for tokens active on the final
+  // frame.  It outputs to final-costs, if non-NULL, a map from the Token*
+  // pointer to the final-prob of the corresponding state, for all Tokens
+  // that correspond to states that have final-probs.  This map will be
+  // empty if there were no final-probs.  It outputs to
+  // final_relative_cost, if non-NULL, the difference between the best
+  // forward-cost including the final-prob cost, and the best forward-cost
+  // without including the final-prob cost (this will usually be positive), or
+  // infinity if there were no final-probs.  [c.f. FinalRelativeCost(), which
+  // outputs this quantity].  It outputs to final_best_cost, if
+  // non-NULL, the lowest for any token t active on the final frame, of
+  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
+  // the graph of the state corresponding to token t, or the best of
+  // forward-cost[t] if there were no final-probs active on the final frame.
+  // You cannot call this after FinalizeDecoding() has been called; in that
+  // case you should get the answer from class-member variables.
+  void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
+                         BaseFloat *final_relative_cost,
+                         BaseFloat *final_best_cost) const;
+
+  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
+  // on the final frame.  If there are final tokens active, it uses
+  // the final-probs for pruning, otherwise it treats all tokens as final.
+  void PruneForwardLinksFinal();
+
+  // Prune away any tokens on this frame that have no forward links.
+  // [we don't do this in PruneForwardLinks because it would give us
+  // a problem with dangling pointers].
+  // It's called by PruneActiveTokens if any forward links have been pruned.
+  void PruneTokensForFrame(int32 frame_plus_one);
+
+
+  // Go backwards through still-alive tokens, pruning them if the
+  // forward+backward cost is more than lat_beam away from the best path.  It's
+  // possible to prove that this is "correct" in the sense that we won't lose
+  // anything outside of lat_beam, regardless of what happens in the future.
+  // delta controls when it considers a cost to have changed enough to continue
+  // going backward and propagating the change.  Larger delta -> will recurse
+  // less far.
+  void PruneActiveTokens(BaseFloat delta);
+
+  /// Gets the weight cutoff.  Also counts the active tokens.
+  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
+                      BaseFloat *adaptive_beam, Elem **best_elem);
+
+  /// Processes emitting arcs for one frame.  Propagates from prev_toks_ to
+  /// cur_toks_.  Returns the cost cutoff for subsequent ProcessNonemitting() to
+  /// use.
+  BaseFloat ProcessEmitting(DecodableInterface *decodable);
+
+  /// Processes nonemitting (epsilon) arcs for one frame.  Called after
+  /// ProcessEmitting() on each frame.  The cost cutoff is computed by the
+  /// preceding ProcessEmitting().
+  void ProcessNonemitting(BaseFloat cost_cutoff);
+
+  // HashList defined in ../util/hash-list.h.  It actually allows us to maintain
+  // more than one list (e.g. for current and previous frames), but only one of
+  // them at a time can be indexed by StateId.  It is indexed by frame-index
+  // plus one, where the frame-index is zero-based, as used in decodable object.
+  // That is, the emitting probs of frame t are accounted for in tokens at
+  // toks_[t+1].  The zeroth frame is for nonemitting transition at the start of
+  // the graph.
+  HashList<StateId, Token*> toks_;
+
+  std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
+  // frame (members of TokenList are toks, must_prune_forward_links,
+  // must_prune_tokens).
+  std::vector<const Elem* > queue_;  // temp variable used in ProcessNonemitting,
+  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.
+
+  // fst_ is a pointer to the FST we are decoding from.
+  const FST *fst_;
+  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
+  // object is destroyed.
+  bool delete_fst_;
+
+  std::vector<BaseFloat> cost_offsets_; // This contains, for each
+  // frame, an offset that was added to the acoustic log-likelihoods on that
+  // frame in order to keep everything in a nice dynamic range, i.e. close to
+  // zero, to reduce roundoff errors.
+  LatticeFasterDecoderConfig config_;
+  int32 num_toks_; // current total #toks allocated...
+  bool warned_;
+
+  /// decoding_finalized_ is true if someone called FinalizeDecoding().  [note,
+  /// calling this is optional].  If true, it's forbidden to decode more.  Also,
+  /// if this is set, then the output of ComputeFinalCosts() is in the next
+  /// three variables.  The reason we need to do this is that after
+  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
+  /// of the tokens on the last frame are freed, so we free the list from toks_
+  /// to avoid having dangling pointers hanging around.
+  bool decoding_finalized_;
+  /// For the meaning of the next 3 variables, see the comment for
+  /// decoding_finalized_ above, and ComputeFinalCosts().
+  unordered_map<Token*, BaseFloat> final_costs_;
+  BaseFloat final_relative_cost_;
+  BaseFloat final_best_cost_;
+
+  // Memory pools for storing tokens and forward links.
+  // We use them to decrease the work put on the allocator and to move some of
+  // the data together.  Too small block sizes will result in more work for the
+  // allocator, but bigger ones increase the memory usage.
+  fst::MemoryPool<Token> token_pool_;
+  fst::MemoryPool<ForwardLinkT> forward_link_pool_;
+
+  // There are various cleanup tasks... the toks_ structure contains
+  // singly linked lists of Token pointers, where Elem is the list type.
+  // It also indexes them in a hash, indexed by state (this hash is only
+  // maintained for the most recent frame).  toks_.Clear()
+  // deletes them from the hash and returns the list of Elems.  The
+  // function DeleteElems calls toks_.Delete(elem) for each elem in
+  // the list, which returns ownership of the Elem to the toks_ structure
+  // for reuse, but does not delete the Token pointer.  The Token pointers
+  // are reference-counted and are ultimately deleted in PruneTokensForFrame,
+  // but are also linked together on each frame by their own linked-list,
+  // using the "next" pointer.  We delete them manually.
+  void DeleteElems(Elem *list);
+
+  // This function takes a singly linked list of tokens for a single frame, and
+  // outputs a list of them in topological order (it will crash if no such order
+  // can be found, which will typically be due to decoding graphs with epsilon
+  // cycles, which are not allowed).  Note: the output list may contain NULLs,
+  // which the caller should pass over; it just happens to be more efficient for
+  // the algorithm to output a list that contains NULLs.
+ static void TopSortTokens(Token *tok_list, + std::vector *topsorted_list); + + void ClearActiveTokens(); + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl); +}; + +typedef LatticeFasterDecoderTpl LatticeFasterDecoder; + + + +} // end namespace kaldi. + +#endif diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc new file mode 100644 index 00000000000..ebdace7e849 --- /dev/null +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc @@ -0,0 +1,285 @@ +// decoder/lattice-faster-online-decoder.cc + +// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.cc, about how to maintain this +// file in sync with lattice-faster-decoder.cc + +#include "decoder/lattice-faster-online-decoder.h" +#include "lat/lattice-functions.h" + +namespace kaldi { + +template +bool LatticeFasterOnlineDecoderTpl::TestGetBestPath( + bool use_final_probs) const { + Lattice lat1; + { + Lattice raw_lat; + this->GetRawLattice(&raw_lat, use_final_probs); + ShortestPath(raw_lat, &lat1); + } + Lattice lat2; + GetBestPath(&lat2, use_final_probs); + BaseFloat delta = 0.1; + int32 num_paths = 1; + if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) { + KALDI_WARN << "Best-path test failed"; + return false; + } else { + return true; + } +} + + +// Outputs an FST corresponding to the single best path through the lattice. +template +bool LatticeFasterOnlineDecoderTpl::GetBestPath(Lattice *olat, + bool use_final_probs) const { + olat->DeleteStates(); + BaseFloat final_graph_cost; + BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost); + if (iter.Done()) + return false; // would have printed warning. + StateId state = olat->AddState(); + olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0)); + while (!iter.Done()) { + LatticeArc arc; + iter = TraceBackBestPath(iter, &arc); + arc.nextstate = state; + StateId new_state = olat->AddState(); + olat->AddArc(new_state, arc); + state = new_state; + } + olat->SetStart(state); + return true; +} + +template +typename LatticeFasterOnlineDecoderTpl::BestPathIterator LatticeFasterOnlineDecoderTpl::BestPathEnd( + bool use_final_probs, + BaseFloat *final_cost_out) const { + if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "BestPathEnd() with use_final_probs == false"; + KALDI_ASSERT(this->NumFramesDecoded() > 0 && + "You cannot call BestPathEnd if no frames were decoded."); + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? 
this->final_costs_ :final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + // Singly linked list of tokens on last frame (access list through "next" + // pointer). + BaseFloat best_cost = std::numeric_limits::infinity(); + BaseFloat best_final_cost = 0; + Token *best_tok = NULL; + for (Token *tok = this->active_toks_.back().toks; + tok != NULL; tok = tok->next) { + BaseFloat cost = tok->tot_cost, final_cost = 0.0; + if (use_final_probs && !final_costs.empty()) { + // if we are instructed to use final-probs, and any final tokens were + // active on final frame, include the final-prob in the cost of the token. + typename unordered_map::const_iterator + iter = final_costs.find(tok); + if (iter != final_costs.end()) { + final_cost = iter->second; + cost += final_cost; + } else { + cost = std::numeric_limits::infinity(); + } + } + if (cost < best_cost) { + best_cost = cost; + best_tok = tok; + best_final_cost = final_cost; + } + } + if (best_tok == NULL) { // this should not happen, and is likely a code error or + // caused by infinities in likelihoods, but I'm not making + // it a fatal error for now. + KALDI_WARN << "No final token found."; + } + if (final_cost_out) + *final_cost_out = best_final_cost; + return BestPathIterator(best_tok, this->NumFramesDecoded() - 1); +} + + +template +typename LatticeFasterOnlineDecoderTpl::BestPathIterator LatticeFasterOnlineDecoderTpl::TraceBackBestPath( + BestPathIterator iter, LatticeArc *oarc) const { + KALDI_ASSERT(!iter.Done() && oarc != NULL); + Token *tok = static_cast(iter.tok); + int32 cur_t = iter.frame, step_t = 0; + if (tok->backpointer != NULL) { + // retrieve the correct forward link(with the best link cost) + BaseFloat best_cost = std::numeric_limits::infinity(); + ForwardLinkT *link; + for (link = tok->backpointer->links; + link != NULL; link = link->next) { + if (link->next_tok == tok) { // this is a link to "tok" + BaseFloat graph_cost = link->graph_cost, + acoustic_cost = link->acoustic_cost; + BaseFloat cost = graph_cost + acoustic_cost; + if (cost < best_cost) { + oarc->ilabel = link->ilabel; + oarc->olabel = link->olabel; + if (link->ilabel != 0) { + KALDI_ASSERT(static_cast(cur_t) < this->cost_offsets_.size()); + acoustic_cost -= this->cost_offsets_[cur_t]; + step_t = -1; + } else { + step_t = 0; + } + oarc->weight = LatticeWeight(graph_cost, acoustic_cost); + best_cost = cost; + } + } + } + if (link == NULL && + best_cost == std::numeric_limits::infinity()) { // Did not find correct link. + KALDI_ERR << "Error tracing best-path back (likely " + << "bug in token-pruning algorithm)"; + } + } else { + oarc->ilabel = 0; + oarc->olabel = 0; + oarc->weight = LatticeWeight::One(); // zero costs. + } + return BestPathIterator(tok->backpointer, cur_t + step_t); +} + +template +bool LatticeFasterOnlineDecoderTpl::GetRawLatticePruned( + Lattice *ofst, + bool use_final_probs, + BaseFloat beam) const { + typedef LatticeArc Arc; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + typedef Arc::Label Label; + + // Note: you can't use the old interface (Decode()) if you want to + // get the lattice with use_final_probs = false. You'd have to do + // InitDecoding() and then AdvanceDecoding(). 
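+  // That newer calling pattern, as an illustrative sketch (the decoder,
+  // the decodable object and the beam value below are assumed to exist):
+  //
+  //   decoder.InitDecoding();
+  //   decoder.AdvanceDecoding(&decodable);
+  //   Lattice raw_lat;
+  //   decoder.GetRawLatticePruned(&raw_lat, false, beam);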
+ if (this->decoding_finalized_ && !use_final_probs) + KALDI_ERR << "You cannot call FinalizeDecoding() and then call " + << "GetRawLattice() with use_final_probs == false"; + + unordered_map final_costs_local; + + const unordered_map &final_costs = + (this->decoding_finalized_ ? this->final_costs_ : final_costs_local); + if (!this->decoding_finalized_ && use_final_probs) + this->ComputeFinalCosts(&final_costs_local, NULL, NULL); + + ofst->DeleteStates(); + // num-frames plus one (since frames are one-based, and we have + // an extra frame for the start-state). + int32 num_frames = this->active_toks_.size() - 1; + KALDI_ASSERT(num_frames > 0); + for (int32 f = 0; f <= num_frames; f++) { + if (this->active_toks_[f].toks == NULL) { + KALDI_WARN << "No tokens active on frame " << f + << ": not producing lattice.\n"; + return false; + } + } + unordered_map tok_map; + std::queue > tok_queue; + // First initialize the queue and states. Put the initial state on the queue; + // this is the last token in the list active_toks_[0].toks. + for (Token *tok = this->active_toks_[0].toks; + tok != NULL; tok = tok->next) { + if (tok->next == NULL) { + tok_map[tok] = ofst->AddState(); + ofst->SetStart(tok_map[tok]); + std::pair tok_pair(tok, 0); // #frame = 0 + tok_queue.push(tok_pair); + } + } + + // Next create states for "good" tokens + while (!tok_queue.empty()) { + std::pair cur_tok_pair = tok_queue.front(); + tok_queue.pop(); + Token *cur_tok = cur_tok_pair.first; + int32 cur_frame = cur_tok_pair.second; + KALDI_ASSERT(cur_frame >= 0 && + cur_frame <= this->cost_offsets_.size()); + + typename unordered_map::const_iterator iter = + tok_map.find(cur_tok); + KALDI_ASSERT(iter != tok_map.end()); + StateId cur_state = iter->second; + + for (ForwardLinkT *l = cur_tok->links; + l != NULL; + l = l->next) { + Token *next_tok = l->next_tok; + if (next_tok->extra_cost < beam) { + // so both the current and the next token are good; create the arc + int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1; + StateId nextstate; + if (tok_map.find(next_tok) == tok_map.end()) { + nextstate = tok_map[next_tok] = ofst->AddState(); + tok_queue.push(std::pair(next_tok, next_frame)); + } else { + nextstate = tok_map[next_tok]; + } + BaseFloat cost_offset = (l->ilabel != 0 ? + this->cost_offsets_[cur_frame] : 0); + Arc arc(l->ilabel, l->olabel, + Weight(l->graph_cost, l->acoustic_cost - cost_offset), + nextstate); + ofst->AddArc(cur_state, arc); + } + } + if (cur_frame == num_frames) { + if (use_final_probs && !final_costs.empty()) { + typename unordered_map::const_iterator iter = + final_costs.find(cur_tok); + if (iter != final_costs.end()) + ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0)); + } else { + ofst->SetFinal(cur_state, LatticeWeight::One()); + } + } + } + return (ofst->NumStates() != 0); +} + + + +// Instantiate the template for the FST types that we'll need. +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl >; +template class LatticeFasterOnlineDecoderTpl; +template class LatticeFasterOnlineDecoderTpl; + + +} // end namespace kaldi. 
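The BestPathEnd() / TraceBackBestPath() pair implemented above is what makes cheap endpointing possible: the current best path is read off the token backpointers without materializing a lattice first. A minimal sketch of that traceback loop, mirroring the GetBestPath() implementation above; it assumes a LatticeFasterOnlineDecoder named decoder that has already consumed frames via AdvanceDecoding(), and collects the output labels (e.g. word ids):

#include <algorithm>  // for std::reverse

// Walk the current best path backwards and collect its output labels.
kaldi::BaseFloat final_cost;
kaldi::LatticeFasterOnlineDecoder::BestPathIterator iter =
    decoder.BestPathEnd(true /* use_final_probs */, &final_cost);
std::vector<kaldi::int32> words;
while (!iter.Done()) {
  kaldi::LatticeArc arc;
  iter = decoder.TraceBackBestPath(iter, &arc);
  if (arc.olabel != 0) words.push_back(arc.olabel);
}
std::reverse(words.begin(), words.end());  // traceback runs last-to-first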
diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h new file mode 100644 index 00000000000..8b10996fd0b --- /dev/null +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h @@ -0,0 +1,147 @@ +// decoder/lattice-faster-online-decoder.h + +// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann; +// 2013-2014 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen +// 2018 Zhehuai Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// see note at the top of lattice-faster-decoder.h, about how to maintain this +// file in sync with lattice-faster-decoder.h + + +#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_ +#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_ + +#include "util/stl-utils.h" +#include "util/hash-list.h" +#include "fst/fstlib.h" +#include "itf/decodable-itf.h" +#include "fstext/fstext-lib.h" +#include "lat/determinize-lattice-pruned.h" +#include "lat/kaldi-lattice.h" +#include "decoder/lattice-faster-decoder.h" + +namespace kaldi { + + + +/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also + supports an efficient way to get the best path (see the function + BestPathEnd()), which is useful in endpointing and in situations where you + might want to frequently access the best path. + + This is only templated on the FST type, since the Token type is required to + be BackpointerToken. Actually it only makes sense to instantiate + LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via + this child class. + */ +template +class LatticeFasterOnlineDecoderTpl: + public LatticeFasterDecoderTpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Token = decoder::BackpointerToken; + using ForwardLinkT = decoder::ForwardLink; + + // Instantiate this class once for each thing you have to decode. + // This version of the constructor does not take ownership of + // 'fst'. + LatticeFasterOnlineDecoderTpl(const FST &fst, + const LatticeFasterDecoderConfig &config): + LatticeFasterDecoderTpl(fst, config) { } + + // This version of the initializer takes ownership of 'fst', and will delete + // it when this object is destroyed. + LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config, + FST *fst): + LatticeFasterDecoderTpl(config, fst) { } + + + struct BestPathIterator { + void *tok; + int32 frame; + // note, "frame" is the frame-index of the frame you'll get the + // transition-id for next time, if you call TraceBackBestPath on this + // iterator (assuming it's not an epsilon transition). Note that this + // is one less than you might reasonably expect, e.g. it's -1 for + // the nonemitting transitions before the first frame. 
+ BestPathIterator(void *t, int32 f): tok(t), frame(f) { } + bool Done() const { return tok == NULL; } + }; + + + /// Outputs an FST corresponding to the single best path through the lattice. + /// This is quite efficient because it doesn't get the entire raw lattice and find + /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator + /// so it basically traces it back through the lattice. + /// Returns true if result is nonempty (using the return status is deprecated, + /// it will become void). If "use_final_probs" is true AND we reached the + /// final-state of the graph then it will include those as final-probs, else + /// it will treat all final-probs as one. + bool GetBestPath(Lattice *ofst, + bool use_final_probs = true) const; + + + /// This function does a self-test of GetBestPath(). Returns true on + /// success; returns false and prints a warning on failure. + bool TestGetBestPath(bool use_final_probs = true) const; + + + /// This function returns an iterator that can be used to trace back + /// the best path. If use_final_probs == true and at least one final state + /// survived till the end, it will use the final-probs in working out the best + /// final Token, and will output the final cost to *final_cost (if non-NULL), + /// else it will use only the forward likelihood, and will put zero in + /// *final_cost (if non-NULL). + /// Requires that NumFramesDecoded() > 0. + BestPathIterator BestPathEnd(bool use_final_probs, + BaseFloat *final_cost = NULL) const; + + + /// This function can be used in conjunction with BestPathEnd() to trace back + /// the best path one link at a time (e.g. this can be useful in endpoint + /// detection). By "link" we mean a link in the graph; not all links cross + /// frame boundaries, but each time you see a nonzero ilabel you can interpret + /// that as a frame. The return value is the updated iterator. It outputs + /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer, + /// while leaving its "nextstate" variable unchanged. + BestPathIterator TraceBackBestPath( + BestPathIterator iter, LatticeArc *arc) const; + + + /// Behaves the same as GetRawLattice but only processes tokens whose + /// extra_cost is smaller than the best-cost plus the specified beam. + /// It is only worthwhile to call this function if beam is less than + /// the lattice_beam specified in the config; otherwise, it would + /// return essentially the same thing as GetRawLattice, but more slowly. + bool GetRawLatticePruned(Lattice *ofst, + bool use_final_probs, + BaseFloat beam) const; + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl); +}; + +typedef LatticeFasterOnlineDecoderTpl LatticeFasterOnlineDecoder; + + +} // end namespace kaldi. 
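+// A typical offline use of this class, as an illustrative sketch only (graph
+// loading and construction of the decodable object are omitted, and all names
+// below are assumed):
+//
+//   kaldi::LatticeFasterDecoderConfig config;
+//   config.beam = 12.0;
+//   config.lattice_beam = 6.0;
+//   kaldi::LatticeFasterOnlineDecoder decoder(*decode_fst, config);
+//   decoder.Decode(&decodable);
+//   if (decoder.ReachedFinal()) {
+//     kaldi::Lattice best_path;
+//     decoder.GetBestPath(&best_path);
+//   }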
+ +#endif diff --git a/speechx/speechx/kaldi/feat/CMakeLists.txt b/speechx/speechx/kaldi/feat/CMakeLists.txt index 8b914962129..c3a996ffb20 100644 --- a/speechx/speechx/kaldi/feat/CMakeLists.txt +++ b/speechx/speechx/kaldi/feat/CMakeLists.txt @@ -15,5 +15,6 @@ add_library(kaldi-feat-common feature-window.cc resample.cc mel-computations.cc + cmvn.cc ) -target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util) \ No newline at end of file +target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util) diff --git a/speechx/speechx/kaldi/feat/cmvn.cc b/speechx/speechx/kaldi/feat/cmvn.cc new file mode 100644 index 00000000000..b2aa46e4fb1 --- /dev/null +++ b/speechx/speechx/kaldi/feat/cmvn.cc @@ -0,0 +1,183 @@ +// transform/cmvn.cc + +// Copyright 2009-2013 Microsoft Corporation +// Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "feat/cmvn.h" + +namespace kaldi { + +void InitCmvnStats(int32 dim, Matrix *stats) { + KALDI_ASSERT(dim > 0); + stats->Resize(2, dim+1); +} + +void AccCmvnStats(const VectorBase &feats, BaseFloat weight, MatrixBase *stats) { + int32 dim = feats.Dim(); + KALDI_ASSERT(stats != NULL); + KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() == dim + 1); + // Remove these __restrict__ modifiers if they cause compilation problems. + // It's just an optimization. + double *__restrict__ mean_ptr = stats->RowData(0), + *__restrict__ var_ptr = stats->RowData(1), + *__restrict__ count_ptr = mean_ptr + dim; + const BaseFloat * __restrict__ feats_ptr = feats.Data(); + *count_ptr += weight; + // Careful-- if we change the format of the matrix, the "mean_ptr < count_ptr" + // statement below might become wrong. + for (; mean_ptr < count_ptr; mean_ptr++, var_ptr++, feats_ptr++) { + *mean_ptr += *feats_ptr * weight; + *var_ptr += *feats_ptr * *feats_ptr * weight; + } +} + +void AccCmvnStats(const MatrixBase &feats, + const VectorBase *weights, + MatrixBase *stats) { + int32 num_frames = feats.NumRows(); + if (weights != NULL) { + KALDI_ASSERT(weights->Dim() == num_frames); + } + for (int32 i = 0; i < num_frames; i++) { + SubVector this_frame = feats.Row(i); + BaseFloat weight = (weights == NULL ? 
1.0 : (*weights)(i)); + if (weight != 0.0) + AccCmvnStats(this_frame, weight, stats); + } +} + +void ApplyCmvn(const MatrixBase &stats, + bool var_norm, + MatrixBase *feats) { + KALDI_ASSERT(feats != NULL); + int32 dim = stats.NumCols() - 1; + if (stats.NumRows() > 2 || stats.NumRows() < 1 || feats->NumCols() != dim) { + KALDI_ERR << "Dim mismatch: cmvn " + << stats.NumRows() << 'x' << stats.NumCols() + << ", feats " << feats->NumRows() << 'x' << feats->NumCols(); + } + if (stats.NumRows() == 1 && var_norm) + KALDI_ERR << "You requested variance normalization but no variance stats " + << "are supplied."; + + double count = stats(0, dim); + // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when + // computing an offset and representing it as stats, we use a count of one. + if (count < 1.0) + KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: " + << "count = " << count; + + if (!var_norm) { + Vector offset(dim); + SubVector mean_stats(stats.RowData(0), dim); + offset.AddVec(-1.0 / count, mean_stats); + feats->AddVecToRows(1.0, offset); + return; + } + // norm(0, d) = mean offset; + // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). + Matrix norm(2, dim); + for (int32 d = 0; d < dim; d++) { + double mean, offset, scale; + mean = stats(0, d)/count; + double var = (stats(1, d)/count) - mean*mean, + floor = 1.0e-20; + if (var < floor) { + KALDI_WARN << "Flooring cepstral variance from " << var << " to " + << floor; + var = floor; + } + scale = 1.0 / sqrt(var); + if (scale != scale || 1/scale == 0.0) + KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; + offset = -(mean*scale); + norm(0, d) = offset; + norm(1, d) = scale; + } + // Apply the normalization. + feats->MulColsVec(norm.Row(1)); + feats->AddVecToRows(1.0, norm.Row(0)); +} + +void ApplyCmvnReverse(const MatrixBase &stats, + bool var_norm, + MatrixBase *feats) { + KALDI_ASSERT(feats != NULL); + int32 dim = stats.NumCols() - 1; + if (stats.NumRows() > 2 || stats.NumRows() < 1 || feats->NumCols() != dim) { + KALDI_ERR << "Dim mismatch: cmvn " + << stats.NumRows() << 'x' << stats.NumCols() + << ", feats " << feats->NumRows() << 'x' << feats->NumCols(); + } + if (stats.NumRows() == 1 && var_norm) + KALDI_ERR << "You requested variance normalization but no variance stats " + << "are supplied."; + + double count = stats(0, dim); + // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when + // computing an offset and representing it as stats, we use a count of one. + if (count < 1.0) + KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: " + << "count = " << count; + + Matrix norm(2, dim); // norm(0, d) = mean offset + // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). + for (int32 d = 0; d < dim; d++) { + double mean, offset, scale; + mean = stats(0, d) / count; + if (!var_norm) { + scale = 1.0; + offset = mean; + } else { + double var = (stats(1, d)/count) - mean*mean, + floor = 1.0e-20; + if (var < floor) { + KALDI_WARN << "Flooring cepstral variance from " << var << " to " + << floor; + var = floor; + } + // we aim to transform zero-mean, unit-variance input into data + // with the given mean and variance. 
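+      // Worked example with illustrative numbers: if the accumulated stats
+      // give mean = 2.0 and var = 4.0 for this dimension, a unit-variance
+      // input value x = 1.0 ends up as sqrt(4.0) * 1.0 + 2.0 = 4.0; in
+      // general, x_out = sqrt(var) * x_in + mean.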
+ scale = sqrt(var); + offset = mean; + } + norm(0, d) = offset; + norm(1, d) = scale; + } + if (var_norm) + feats->MulColsVec(norm.Row(1)); + feats->AddVecToRows(1.0, norm.Row(0)); +} + + +void FakeStatsForSomeDims(const std::vector &dims, + MatrixBase *stats) { + KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() > 1); + int32 dim = stats->NumCols() - 1; + double count = (*stats)(0, dim); + for (size_t i = 0; i < dims.size(); i++) { + int32 d = dims[i]; + KALDI_ASSERT(d >= 0 && d < dim); + (*stats)(0, d) = 0.0; + (*stats)(1, d) = count; + } +} + + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/cmvn.h b/speechx/speechx/kaldi/feat/cmvn.h new file mode 100644 index 00000000000..c6d1b7f74e1 --- /dev/null +++ b/speechx/speechx/kaldi/feat/cmvn.h @@ -0,0 +1,75 @@ +// transform/cmvn.h + +// Copyright 2009-2013 Microsoft Corporation +// Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_TRANSFORM_CMVN_H_ +#define KALDI_TRANSFORM_CMVN_H_ + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { + +/// This function initializes the matrix to dimension 2 by (dim+1); +/// 1st "dim" elements of 1st row are mean stats, 1st "dim" elements +/// of 2nd row are var stats, last element of 1st row is count, +/// last element of 2nd row is zero. +void InitCmvnStats(int32 dim, Matrix *stats); + +/// Accumulation from a single frame (weighted). +void AccCmvnStats(const VectorBase &feat, + BaseFloat weight, + MatrixBase *stats); + +/// Accumulation from a feature file (possibly weighted-- useful in excluding silence). +void AccCmvnStats(const MatrixBase &feats, + const VectorBase *weights, // or NULL + MatrixBase *stats); + +/// Apply cepstral mean and variance normalization to a matrix of features. +/// If norm_vars == true, expects stats to be of dimension 2 by (dim+1), but +/// if norm_vars == false, will accept stats of dimension 1 by (dim+1); these +/// are produced by the balanced-cmvn code when it computes an offset and +/// represents it as "fake stats". +void ApplyCmvn(const MatrixBase &stats, + bool norm_vars, + MatrixBase *feats); + +/// This is as ApplyCmvn, but does so in the reverse sense, i.e. applies a transform +/// that would take zero-mean, unit-variance input and turn it into output with the +/// stats of "stats". This can be useful if you trained without CMVN but later want +/// to correct a mismatch, so you would first apply CMVN and then do the "reverse" +/// CMVN with the summed stats of your training data. +void ApplyCmvnReverse(const MatrixBase &stats, + bool norm_vars, + MatrixBase *feats); + + +/// Modify the stats so that for some dimensions (specified in "dims"), we +/// replace them with "fake" stats that have zero mean and unit variance; this +/// is done to disable CMVN for those dimensions. 
+void FakeStatsForSomeDims(const std::vector &dims, + MatrixBase *stats); + + + +} // namespace kaldi + +#endif // KALDI_TRANSFORM_CMVN_H_ diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc b/speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc new file mode 100644 index 00000000000..f6684f0b5b5 --- /dev/null +++ b/speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc @@ -0,0 +1,147 @@ +// lat/determinize-lattice-pruned-test.cc + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2013 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "lat/determinize-lattice-pruned.h" +#include "fstext/lattice-utils.h" +#include "fstext/fst-test-utils.h" +#include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" + +namespace fst { +// Caution: these tests are not as generic as you might think from all the +// templates in the code. They are basically only valid for LatticeArc. +// This is partly due to the fact that certain templates need to be instantiated +// in other .cc files in this directory. + +// test that determinization proceeds correctly on general +// FSTs (not guaranteed determinzable, but we use the +// max-states option to stop it getting out of control). +template void TestDeterminizeLatticePruned() { + typedef kaldi::int32 Int; + typedef typename Arc::Weight Weight; + typedef ArcTpl > CompactArc; + + for(int i = 0; i < 100; i++) { + RandFstOptions opts; + opts.n_states = 4; + opts.n_arcs = 10; + opts.n_final = 2; + opts.allow_empty = false; + opts.weight_multiplier = 0.5; // impt for the randomly generated weights + opts.acyclic = true; + // to be exactly representable in float, + // or this test fails because numerical differences can cause symmetry in + // weights to be broken, which causes the wrong path to be chosen as far + // as the string part is concerned. + + VectorFst *fst = RandPairFst(opts); + + bool sorted = TopSort(fst); + KALDI_ASSERT(sorted); + + ILabelCompare ilabel_comp; + if (kaldi::Rand() % 2 == 0) + ArcSort(fst, ilabel_comp); + + std::cout << "FST before lattice-determinizing is:\n"; + { + FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + VectorFst det_fst; + try { + DeterminizeLatticePrunedOptions lat_opts; + lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000); + lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20); + lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? 
-1 : 30); + bool ans = DeterminizeLatticePruned(*fst, 10.0, &det_fst, lat_opts); + + std::cout << "FST after lattice-determinizing is:\n"; + { + FstPrinter fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic); + // OK, now determinize it a different way and check equivalence. + // [note: it's not normal determinization, it's taking the best path + // for any input-symbol sequence.... + + + VectorFst pruned_fst(*fst); + if (pruned_fst.NumStates() != 0) + kaldi::PruneLattice(10.0, &pruned_fst); + + VectorFst compact_pruned_fst, compact_pruned_det_fst; + ConvertLattice(pruned_fst, &compact_pruned_fst, false); + std::cout << "Compact pruned FST is:\n"; + { + FstPrinter fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + ConvertLattice(det_fst, &compact_pruned_det_fst, false); + + std::cout << "Compact version of determinized FST is:\n"; + { + FstPrinter fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + + if (ans) + KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/)); + } catch (...) { + std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n"; + } + delete fst; + } +} + +// test that determinization proceeds without crash on acyclic FSTs +// (guaranteed determinizable in this sense). +template void TestDeterminizeLatticePruned2() { + typedef typename Arc::Weight Weight; + RandFstOptions opts; + opts.acyclic = true; + for(int i = 0; i < 100; i++) { + VectorFst *fst = RandPairFst(opts); + std::cout << "FST before lattice-determinizing is:\n"; + { + FstPrinter fstprinter(*fst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + VectorFst ofst; + DeterminizeLatticePruned(*fst, 10.0, &ofst); + std::cout << "FST after lattice-determinizing is:\n"; + { + FstPrinter fstprinter(ofst, NULL, NULL, NULL, false, true, "\t"); + fstprinter.Print(&std::cout, "standard output"); + } + delete fst; + } +} + + +} // end namespace fst + +int main() { + using namespace fst; + TestDeterminizeLatticePruned(); + TestDeterminizeLatticePruned2(); + std::cout << "Tests succeeded\n"; +} diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc b/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc new file mode 100644 index 00000000000..dbdd9af4645 --- /dev/null +++ b/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc @@ -0,0 +1,1541 @@ +// lat/determinize-lattice-pruned.cc + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2013 Johns Hopkins University (Author: Daniel Povey) +// 2014 Guoguo Chen + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "fstext/determinize-lattice.h" // for LatticeStringRepository +#include "fstext/fstext-utils.h" +#include "lat/lattice-functions.h" // for PruneLattice +#include "lat/minimize-lattice.h" // for minimization +#include "lat/push-lattice.h" // for minimization +#include "lat/determinize-lattice-pruned.h" + +namespace fst { + +using std::vector; +using std::pair; +using std::greater; + +// class LatticeDeterminizerPruned is templated on the same types that +// CompactLatticeWeight is templated on: the base weight (Weight), typically +// LatticeWeightTpl etc. but could also be e.g. TropicalWeight, and the +// IntType, typically int32, used for the output symbols in the compact +// representation of strings [note: the output symbols would usually be +// p.d.f. id's in the anticipated use of this code] It has a special requirement +// on the Weight type: that there should be a Compare function on the weights +// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 > +// w2. This requires that there be a total order on the weights. + +template class LatticeDeterminizerPruned { + public: + // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 correspondence + // between our states and the states in ofst. If destroy == true, release memory as we go + // (but we cannot output again). + + typedef CompactLatticeWeightTpl CompactWeight; + typedef ArcTpl CompactArc; // arc in compact, acceptor form of lattice + typedef ArcTpl Arc; // arc in non-compact version of lattice + + // Output to standard FST with CompactWeightTpl as its weight type (the + // weight stores the original output-symbol strings). If destroy == true, + // release memory as we go (but we cannot output again). + void Output(MutableFst *ofst, bool destroy = true) { + KALDI_ASSERT(determinized_); + typedef typename Arc::StateId StateId; + StateId nStates = static_cast(output_states_.size()); + if (destroy) + FreeMostMemory(); + ofst->DeleteStates(); + ofst->SetStart(kNoStateId); + if (nStates == 0) { + return; + } + for (StateId s = 0;s < nStates;s++) { + OutputStateId news = ofst->AddState(); + KALDI_ASSERT(news == s); + } + ofst->SetStart(0); + // now process transitions. + for (StateId this_state_id = 0; this_state_id < nStates; this_state_id++) { + OutputState &this_state = *(output_states_[this_state_id]); + vector &this_vec(this_state.arcs); + typename vector::const_iterator iter = this_vec.begin(), end = this_vec.end(); + + for (;iter != end; ++iter) { + const TempArc &temp_arc(*iter); + CompactArc new_arc; + vector