Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[asr][weboscket]fix the streaming asr server bug, server client #1761

Merged
merged 5 commits into from
Apr 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions demos/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The directory containes many speech applications in multi scenarios.
* punctuation_restoration - restore punctuation from raw text
* speech recogintion - recognize text of an audio file
* speech server - Server for Speech Task, e.g. ASR,TTS,CLS
* streaming asr server - receive audio stream from websocket, and recognize to transcript.
* speech translation - end to end speech translation
* story talker - book reader based on OCR and TTS
* style_fs2 - multi style control for FastSpeech2 model
Expand Down
1 change: 1 addition & 0 deletions demos/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
* 标点恢复 - 通常作为语音识别的文本后处理任务,为一段无标点的纯文本添加相应的标点符号。
* 语音识别 - 识别一段音频中包含的语音文字。
* 语音服务 - 离线语音服务,包括ASR、TTS、CLS等
* 流式语音识别服务 - 流式输入语音数据流识别音频中的文字
* 语音翻译 - 实时识别音频中的语言,并同时翻译成目标语言。
* 会说话的故事书 - 基于 OCR 和语音合成的会说话的故事书。
* 个性化语音合成 - 基于 FastSpeech2 模型的个性化语音合成。
Expand Down
4 changes: 2 additions & 2 deletions demos/streaming_asr_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
paddlespeech_server start --help
```
Arguments:
- `config_file`: yaml file of the app, defalut: ./conf/ws_conformer_application.yaml
- `log_file`: log file. Default: ./log/paddlespeech.log
- `config_file`: yaml file of the app, defalut: `./conf/application.yaml`
- `log_file`: log file. Default: `./log/paddlespeech.log`

Output:
```bash
Expand Down
4 changes: 2 additions & 2 deletions demos/streaming_asr_server/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
paddlespeech_server start --help
```
参数:
- `config_file`: 服务的配置文件,默认: ./conf/ws_conformer_application.yaml
- `log_file`: log 文件. 默认:./log/paddlespeech.log
- `config_file`: 服务的配置文件,默认: `./conf/application.yaml`
- `log_file`: log 文件. 默认:`./log/paddlespeech.log`

输出:
```bash
Expand Down
119 changes: 119 additions & 0 deletions paddlespeech/server/utils/audio_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging

import numpy as np
import soundfile
import websockets

from paddlespeech.cli.log import logger


class ASRAudioHandler:
def __init__(self, url="127.0.0.1", port=8090):
"""PaddleSpeech Online ASR Server Client audio handler
Online asr server use the websocket protocal
Args:
url (str, optional): the server ip. Defaults to "127.0.0.1".
port (int, optional): the server port. Defaults to 8090.
"""
self.url = url
self.port = port
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"

def read_wave(self, wavfile_path: str):
"""read the audio file from specific wavfile path

Args:
wavfile_path (str): the audio wavfile,
we assume that audio sample rate matches the model

Yields:
numpy.array: the samall package audio pcm data
"""
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)

chunk_size = 85 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0

padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)

assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_size
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk

async def run(self, wavfile_path: str):
"""Send a audio file to online server

Args:
wavfile_path (str): audio path

Returns:
str: the final asr result
"""
logging.info("send a message to the server")

# 1. send websocket handshake protocal
async with websockets.connect(self.url) as ws:
# 2. server has already received handshake protocal
# client start to send the command
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "start",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
logger.info("receive msg={}".format(msg))

# 3. send chunk audio data to engine
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
logger.info("receive msg={}".format(msg))

# 4. we must send finished signal to the server
audio_info = json.dumps(
{
"name": "test.wav",
"signal": "end",
"nbest": 5
},
sort_keys=True,
indent=4,
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()

# 5. decode the bytes to str
msg = json.loads(msg)
logger.info("final receive msg={}".format(msg))
result = msg
return result
48 changes: 28 additions & 20 deletions paddlespeech/server/ws/asr_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,50 +20,52 @@

from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio

router = APIRouter()


@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online ASR Server api

Args:
websocket (WebSocket): the websocket instance
"""

#1. the interface wait to accept the websocket protocal header
# and only we receive the header, it establish the connection with specific thread
await websocket.accept()

#2. if we accept the websocket headers, we will get the online asr engine instance
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
connection_handler = None
# init buffer
# each websocekt connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=chunk_buffer_conf.window_n,
shift_n=chunk_buffer_conf.shift_n,
window_ms=chunk_buffer_conf.window_ms,
shift_ms=chunk_buffer_conf.shift_ms,
sample_rate=chunk_buffer_conf.sample_rate,
sample_width=chunk_buffer_conf.sample_width)

# init vad
vad_conf = asr_engine.config.get('vad_conf', None)
if vad_conf:
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
#3. each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio
# and each connection has its own connection instance to process the request
# and only if client send the start signal, we create the PaddleASRConnectionHanddler instance
connection_handler = None

try:
#4. we do a loop to process the audio package by package according the protocal
# and only if the client send finished signal, we will break the loop
while True:
# careful here, changed the source code from starlette.websockets
# 4.1 we wait for the client signal for the specific action
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)

#4.2 text for the action command and bytes for pcm data
if "text" in message:
# we first parse the specific command
message = json.loads(message["text"])
if 'signal' not in message:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)

# start command, we create the PaddleASRConnectionHanddler instance to process the audio data
# end command, we process the all the last audio pcm and return the final result
# and we break the loop
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at begining here
Expand All @@ -72,6 +74,7 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_json(resp)
elif message['signal'] == 'end':
# reset single engine for an new connection
# and we will destroy the connection
connection_handler.decode(is_finished=True)
connection_handler.rescoring()
asr_results = connection_handler.get_result()
Expand All @@ -88,12 +91,17 @@ async def websocket_endpoint(websocket: WebSocket):
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
elif "bytes" in message:
# bytes for the pcm data
message = message["bytes"]

# we extract the remained audio pcm
# and decode for the result in this package data
connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
asr_results = connection_handler.get_result()

# return the current period result
# if the engine create the vad instance, this connection will have many period results
resp = {'asr_results': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
Expand Down