Commit

merge the develop
ljhzxc committed Apr 18, 2023
2 parents b327368 + bd0d69c commit 7e56299
Showing 22 changed files with 1,361 additions and 26 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -178,6 +178,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural Language Processing (NLP) and Computer Vision (CV).

### Recent Update
- 🔥 2023.04.06: Add [subtitle file (.srt format) generation example](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples with the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); the effect is continuously optimized.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
- 🎉 2023.03.07: Add [TTS ARM Linux C++ Demo (with C++ Chinese Text Frontend)](./demos/TTSArmLinux).
1 change: 1 addition & 0 deletions README_cn.md
@@ -183,6 +183,7 @@
- 🧩 Cascaded models application: as an extension of traditional speech tasks, we combine Natural Language Processing, Computer Vision and other tasks to build industrial-grade applications closer to real-world needs.

### Recent Update
- 👑 2023.04.06: Add [subtitle file (.srt format) generation](./demos/streaming_asr_server).
- 🔥 2023.03.14: Add SVS (Singing Voice Synthesis) examples with the Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1), [PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5); the effect is continuously optimized.
- 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3).
- 🎉 2023.03.07: Add [TTS ARM Linux C++ deployment example (with C++ Chinese text frontend module)](./demos/TTSArmLinux).
351 changes: 351 additions & 0 deletions demos/streaming_asr_server/README.md


351 changes: 351 additions & 0 deletions demos/streaming_asr_server/README_cn.md


162 changes: 162 additions & 0 deletions demos/streaming_asr_server/local/websocket_client_srt.py
@@ -0,0 +1,162 @@
#!/usr/bin/python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# calc avg RTF (not accurate): grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
# python3 websocket_client_srt.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
# python3 websocket_client_srt.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
import argparse
import asyncio
import codecs
import os
import re

from pydub import AudioSegment

from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler


def convert_to_wav(input_file):
    # Load audio file
    audio = AudioSegment.from_file(input_file)

    # Set parameters for audio file: mono, 16 kHz
    audio = audio.set_channels(1)
    audio = audio.set_frame_rate(16000)

    # Create output filename
    output_file = os.path.splitext(input_file)[0] + ".wav"

    # Export audio file as WAV
    audio.export(output_file, format="wav")

    logger.info(f"{input_file} converted to {output_file}")

def format_time(sec):
    # Convert seconds to SRT time format (HH:MM:SS,ms)
    hours = int(sec / 3600)
    minutes = int((sec % 3600) / 60)
    seconds = int(sec % 60)
    milliseconds = int((sec % 1) * 1000)
    return f'{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}'
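
# e.g. format_time(3723.5) -> '01:02:03,500'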

def results2srt(results, srt_file):
    """Convert results from paddlespeech to SRT format for subtitles.
    Args:
        results (dict): results from paddlespeech, with the recognized text
            in 'result' and per-character timestamps in 'times'
        srt_file (str): path of the output .srt file
    """
    # times contains start and end time of each word
    times = results['times']
    # result contains the whole sentence including punctuation
    result = results['result']
    # split result into several sentences by ',' and '。'
    sentences = re.split(',|。', result)[:-1]
    # print("sentences: ", sentences)
    # generate relative time for each sentence in sentences
    relative_times = []
    word_i = 0
    for sentence in sentences:
        relative_times.append([])
        for word in sentence:
            if relative_times[-1] == []:
                relative_times[-1].append(times[word_i]['bg'])
            if len(relative_times[-1]) == 1:
                relative_times[-1].append(times[word_i]['ed'])
            else:
                relative_times[-1][1] = times[word_i]['ed']
            word_i += 1
    # print("relative_times: ", relative_times)
    # generate srt file according to relative_times and sentences
    with open(srt_file, 'w') as f:
        for i in range(len(sentences)):
            # Write index number
            f.write(str(i + 1) + '\n')

            # Write start and end times
            start = format_time(relative_times[i][0])
            end = format_time(relative_times[i][1])
            f.write(start + ' --> ' + end + '\n')

            # Write text
            f.write(sentences[i] + '\n\n')
    logger.info(f"results saved to {srt_file}")

def main(args):
    logger.info("asr websocket client start")
    handler = ASRWsAudioHandler(
        args.server_ip,
        args.port,
        endpoint=args.endpoint,
        punc_server_ip=args.punc_server_ip,
        punc_server_port=args.punc_server_port)
    loop = asyncio.get_event_loop()

    # if the input audio is an mp3 file,
    # convert it to wav format first using convert_to_wav
    if args.wavfile and os.path.exists(args.wavfile):
        if args.wavfile.endswith(".mp3"):
            convert_to_wav(args.wavfile)
            args.wavfile = os.path.splitext(args.wavfile)[0] + ".wav"

    # support to process single audio file
    if args.wavfile and os.path.exists(args.wavfile):
        logger.info(f"start to process the audio file: {args.wavfile}")
        result = loop.run_until_complete(handler.run(args.wavfile))
        # result = result["result"]
        # logger.info(f"asr websocket client finished : {result}")
        results2srt(result, os.path.splitext(args.wavfile)[0] + ".srt")

    # support to process batch audios from wav.scp
    if args.wavscp and os.path.exists(args.wavscp):
        logger.info(f"start to process the wavscp: {args.wavscp}")
        with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
                codecs.open("result.txt", 'w', encoding='utf-8') as w:
            for line in f:
                utt_name, utt_path = line.strip().split()
                result = loop.run_until_complete(handler.run(utt_path))
                result = result["result"]
                w.write(f"{utt_name} {result}\n")


if __name__ == "__main__":
logger.info("Start to do streaming asr client")
parser = argparse.ArgumentParser()
parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
parser.add_argument('--port', type=int, default=8090, help='server port')
parser.add_argument(
'--punc.server_ip',
type=str,
default=None,
dest="punc_server_ip",
help='Punctuation server ip')
parser.add_argument(
'--punc.port',
type=int,
default=8091,
dest="punc_server_port",
help='Punctuation server port')
parser.add_argument(
"--endpoint",
type=str,
default="/paddlespeech/asr/streaming",
help="ASR websocket endpoint")
parser.add_argument(
"--wavfile",
action="store",
help="wav file path ",
default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args()

main(args)
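
For a quick offline check of the SRT logic, the conversion can be exercised without a running ASR server. A minimal sketch, assuming the script above is importable as websocket_client_srt; the text and timings below are made up for illustration:

# hypothetical standalone check of results2srt (made-up text and timings)
from websocket_client_srt import results2srt

text = "今天天气真好,我们去公园吧。"
n_chars = len(text.replace(",", "").replace("。", ""))
results = {
    "result": text,
    # one {bg, ed} pair per character, in seconds, as results2srt expects
    "times": [{"bg": 0.25 * i, "ed": 0.25 * (i + 1)} for i in range(n_chars)],
}
results2srt(results, "demo.srt")  # writes two subtitle entries to demo.srt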
2 changes: 1 addition & 1 deletion examples/aishell/asr0/local/train.sh
@@ -1,6 +1,6 @@
#!/bin/bash

if [ $# -lt 2 ] && [ $# -gt 3 ];then
if [ $# -lt 2 ] || [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
2 changes: 1 addition & 1 deletion examples/aishell/asr1/local/train.sh
@@ -17,7 +17,7 @@ if [ ${seed} != 0 ]; then
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi

if [ $# -lt 2 ] && [ $# -gt 3 ];then
if [ $# -lt 2 ] || [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
2 changes: 1 addition & 1 deletion examples/aishell/asr3/local/train.sh
@@ -1,6 +1,6 @@
#!/bin/bash

if [ $# -lt 2 ] && [ $# -gt 3 ];then
if [ $# -lt 2 ] || [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
19 changes: 15 additions & 4 deletions examples/vctk/vc3/conf/default.yaml
@@ -1,12 +1,23 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# actually unused; 16000 is what is really used
sr: 24000
# the original code loads audio at 24 kHz but extracts mel at 16 kHz; both should be changed to 24 kHz later
fs: 16000
n_fft: 2048
win_length: 1200
hop_length: 300
n_shift: 300
win_length: 1200 # Window length (in samples); 50 ms at 24 kHz.
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.

fmin: 0 # Minimum frequency of Mel basis.
fmax: 8000 # Maximum frequency of Mel basis. fs // 2
n_mels: 80
# only for StarGANv2 VC
norm: # None here
htk: True
power: 2.0


###########################################################
# MODEL SETTING #
###########################################################
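
As a sanity check on the window comments above, the frame durations implied by win_length=1200 and n_shift=300 samples can be computed directly (plain arithmetic, not part of the commit):

# frame durations at the two sample rates mentioned in the config comments
for rate in (24000, 16000):
    print(f"{rate} Hz: win {1200 / rate * 1000:.2f} ms, hop {300 / rate * 1000:.2f} ms")
# 24000 Hz: win 50.00 ms, hop 12.50 ms
# 16000 Hz: win 75.00 ms, hop 18.75 ms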
27 changes: 23 additions & 4 deletions examples/vctk/vc3/local/preprocess.sh
@@ -6,13 +6,32 @@ stop_stage=100
config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/preprocess.py \
        --dataset=vctk \
        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
        --dumpdir=dump \
        --config=${config_path} \
        --num-cpu=20
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # placeholder: bash forbids an empty then-body, so keep a no-op here
    :
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "Normalize ..."
    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speaker-dict=dump/speaker_id_map.txt
fi
79 changes: 67 additions & 12 deletions paddlespeech/t2s/datasets/am_batch_fn.py
@@ -804,18 +804,73 @@ def jets_multi_spk_batch_fn(examples):
batch["spk_id"] = spk_id
return batch

# not finished yet
def starganv2_vc_batch_fn(examples):
    batch = {
        "x_real": None,
        "y_org": None,
        "x_ref": None,
        "x_ref2": None,
        "y_trg": None,
        "z_trg": None,
        "z_trg2": None,
    }
    return batch


# built as a separate collate class because extra parameters need to be passed in
def build_starganv2_vc_collate_fn(latent_dim: int=16, max_mel_length: int=192):

    return StarGANv2VCCollateFn(
        latent_dim=latent_dim, max_mel_length=max_mel_length)


class StarGANv2VCCollateFn:
    """Collate functor for StarGANv2-VC, in the style of common_collate_fn()."""

    def __init__(self, latent_dim: int=16, max_mel_length: int=192):
        self.latent_dim = latent_dim
        self.max_mel_length = max_mel_length

    def random_clip(self, mel: np.ndarray):
        # mel: [80, T]; randomly crop along T to at most max_mel_length frames
        mel_length = mel.shape[1]
        if mel_length > self.max_mel_length:
            random_start = np.random.randint(
                0, mel_length - self.max_mel_length)
            mel = mel[:, random_start:random_start + self.max_mel_length]
        return mel
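
    # e.g. with max_mel_length=192, an [80, 250] mel is cropped to a random
    # [80, 192] window; shorter mels pass through unchanged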

    def __call__(self, examples):
        return self.starganv2_vc_batch_fn(examples)

    def starganv2_vc_batch_fn(self, examples):
        batch_size = len(examples)

        label = [np.array(item["label"], dtype=np.int64) for item in examples]
        ref_label = [
            np.array(item["ref_label"], dtype=np.int64) for item in examples
        ]

        # mels need to be clipped to max_mel_length
        mel = [self.random_clip(item["mel"]) for item in examples]
        ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
        ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]

        mel = batch_sequences(mel)
        ref_mel = batch_sequences(ref_mel)
        ref_mel_2 = batch_sequences(ref_mel_2)

        # convert each batch to paddle.Tensor
        # (B,)
        label = paddle.to_tensor(label)
        ref_label = paddle.to_tensor(ref_label)
        # (B, 80, T)
        mel = paddle.to_tensor(mel)
        ref_mel = paddle.to_tensor(ref_mel)
        ref_mel_2 = paddle.to_tensor(ref_mel_2)

        # random latent codes, shape [B, latent_dim]
        z_trg = paddle.randn([batch_size, self.latent_dim])
        z_trg2 = paddle.randn([batch_size, self.latent_dim])

        batch = {
            "x_real": mel,
            "y_org": label,
            "x_ref": ref_mel,
            "x_ref2": ref_mel_2,
            "y_trg": ref_label,
            "z_trg": z_trg,
            "z_trg2": z_trg2
        }

        return batch
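
# Illustrative use (hypothetical items; mels follow random_clip's [80, T] layout):
#   collate_fn = build_starganv2_vc_collate_fn(latent_dim=16, max_mel_length=192)
#   item = {"label": 0, "ref_label": 1,
#           "mel": np.random.randn(80, 250).astype(np.float32),
#           "ref_mel": np.random.randn(80, 210).astype(np.float32),
#           "ref_mel_2": np.random.randn(80, 230).astype(np.float32)}
#   batch = collate_fn([item])  # batch["z_trg"].shape == [1, 16]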


# for PaddleSlim
