Skip to content

Commit

Permalink
兼容python3.9及tensorflow2.8等更高版本
Browse files Browse the repository at this point in the history
  • Loading branch information
nl8590687 committed Feb 4, 2023
1 parent ec4b488 commit 5352aff
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 69 deletions.
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ ASRT是一个基于深度学习的中文语音识别系统,如果您觉得喜
* 硬盘: 500 GB 机械硬盘(或固态硬盘)

### 软件
* Linux: Ubuntu 18.04 + / CentOS 7 +
* Python: 3.7 +
* TensorFlow: 2.5 +
* Linux: Ubuntu 18.04 + / CentOS 7 + 或 Windows 10/11
* Python: 3.7 - 3.10 及后续版本
* TensorFlow: 2.5 - 2.11 及后续版本

## 快速开始

Expand Down Expand Up @@ -146,18 +146,17 @@ Github本仓库下[Releases](https://github.com/nl8590687/ASRT_SpeechRecognition

## Python依赖库

* tensorFlow (2.5+)
* tensorFlow (2.5-2.11+)
* numpy
* wave
* matplotlib
* math
* scipy
* requests
* flask
* waitress
* grpcio / grpcio-tools / protobuf

不会安装环境的同学请直接运行以下命令(前提是有GPU且已经安装好 CUDA 11.2 和 cudnn 8.1):
不会安装环境的同学请直接运行以下命令(前提是有GPU且已经安装好 Python3.9、CUDA 11.2 和 cudnn 8.1):

```shell
$ pip install -r requirements.txt
Expand Down
10 changes: 5 additions & 5 deletions README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ This project uses tensorFlow.keras based on deep convolutional neural network an
* 硬盘: 500 GB HDD(or SSD)

### Software
* Linux: Ubuntu 18.04 + / CentOS 7 +
* Python: 3.7 +
* TensorFlow: 2.5 +
* Linux: Ubuntu 18.04 + / CentOS 7 + or Windows 10/11
* Python: 3.7 - 3.10 and later
* TensorFlow: 2.5 - 2.11 and later

## Quick Start
Take the operation under the Linux system as an example:
Expand Down Expand Up @@ -144,7 +144,7 @@ At present, the best model can basically reach 85% of Pinyin correct rate on the

## Python Dependency Library

* tensorFlow (2.5+)
* tensorFlow (2.5-2.11+)
* numpy
* wave
* matplotlib
Expand All @@ -155,7 +155,7 @@ At present, the best model can basically reach 85% of Pinyin correct rate on the
* waitress
* grpcio / grpcio-tools / protobuf

If you have trouble when install those packages, please run the following script to do it as long as you have a GPU and CUDA 11.2 and cudnn 8.1 have been installed:
If you have trouble when install those packages, please run the following script to do it as long as you have a GPU and python 3.9, CUDA 11.2 and cudnn 8.1 have been installed:

```shell
$ pip install -r requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_data_count(self) -> int:
"""
return len(self.data_list)

def get_data(self, index:int) -> tuple:
def get_data(self, index: int) -> tuple:
"""
按下标获取一条数据
"""
Expand Down
18 changes: 9 additions & 9 deletions model_zoo/speech_model/keras_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def forward(self, data_input):

in_len[0] = self.output_shape[0]

x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float)
x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float64)

for i in range(batch_size):
x_in[i, 0:len(data_input)] = data_input
Expand Down Expand Up @@ -298,10 +298,10 @@ def forward(self, data_input):

in_len[0] = self.output_shape[0]

x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float)
x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float64)

for i in range(batch_size):
x_in[i,0:len(data_input)] = data_input
x_in[i, 0:len(data_input)] = data_input

base_pred = self.model_base.predict(x = x_in)
r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
Expand Down Expand Up @@ -398,13 +398,13 @@ def forward(self, data_input):

in_len[0] = self.output_shape[0]

x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float)
x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float64)

for i in range(batch_size):
x_in[i,0:len(data_input)] = data_input
x_in[i, 0:len(data_input)] = data_input

base_pred = self.model_base.predict(x = x_in)
r = K.ctc_decode(base_pred, in_len, greedy = True, beam_width=100, top_paths=1)
base_pred = self.model_base.predict(x=x_in)
r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)

if tf.__version__[0:2] == '1.':
r1 = r[0][0].eval(session=tf.compat.v1.Session())
Expand Down Expand Up @@ -469,7 +469,7 @@ def _define_model(self, input_shape, output_size) -> tuple:
layer_h12 = Dense(output_size, use_bias=True, kernel_initializer='he_normal')(layer_h11) # 全连接层
y_pred = Activation('softmax', name='Activation0')(layer_h12)

model_base = Model(inputs = input_data, outputs = y_pred)
model_base = Model(inputs=input_data, outputs=y_pred)
# model_data.summary()

labels = Input(name='the_labels', shape=[label_max_string_length], dtype='float32')
Expand All @@ -492,7 +492,7 @@ def forward(self, data_input):

in_len[0] = self.output_shape[0]

x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float)
x_in = np.zeros((batch_size,) + self.input_shape, dtype=np.float64)

for i in range(batch_size):
x_in[i, 0:len(data_input)] = data_input
Expand Down
91 changes: 52 additions & 39 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,47 +1,60 @@
absl-py==0.15.0
astor==0.8.0
absl-py==1.4.0
astunparse==1.6.3
cached-property==1.5.2
cachetools==4.2.4
certifi==2019.9.11
charset-normalizer==2.0.7
cycler==0.10.0
flatbuffers==1.12
gast==0.4.0
google-auth==2.3.3
cachetools==5.3.0
certifi==2022.12.7
charset-normalizer==3.0.1
click==8.1.3
colorama==0.4.6
contourpy==1.0.7
cycler==0.11.0
Flask==2.2.2
flatbuffers==23.1.21
fonttools==4.38.0
gast==0.5.3
google-auth==2.16.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.34.1
h5py==3.1.0
idna==3.3
Keras-Applications==1.0.8
keras-nightly==2.5.0.dev2021032900
grpcio==1.51.1
grpcio-tools==1.51.1
h5py==3.8.0
idna==3.4
importlib-metadata==6.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
keras==2.8.0
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.1.1
matplotlib==3.4.2
numpy==1.19.5
oauthlib==3.1.1
kiwisolver==1.4.4
libclang==15.0.6.1
Markdown==3.4.1
MarkupSafe==2.1.2
matplotlib==3.6.3
numpy==1.24.1
oauthlib==3.2.2
opt-einsum==3.3.0
Pillow==8.3.2
protobuf==3.15.0
packaging==23.0
Pillow==9.4.0
protobuf==3.19.6
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.7
python-dateutil==2.8.1
python-speech-features==0.6
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
scipy==1.6.3
six==1.15.0
tensorboard==2.7.0
pyparsing==3.0.9
python-dateutil==2.8.2
requests==2.28.2
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.10.0
six==1.16.0
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow-estimator==2.5.0
tensorflow-gpu==2.5.3
termcolor==1.1.0
typing-extensions==3.7.4.3
urllib3==1.26.7
Werkzeug==0.16.0
wrapt==1.12.1
tensorboard-plugin-wit==1.8.1
tensorflow-estimator==2.8.0
tensorflow-gpu==2.8.4
tensorflow-io-gcs-filesystem==0.30.0
termcolor==2.2.0
typing_extensions==4.4.0
urllib3==1.26.14
waitress==2.1.2
Wave==0.0.2
Werkzeug==2.2.2
wincertstore==0.2
wrapt==1.14.1
zipp==3.12.0
12 changes: 6 additions & 6 deletions speech_features/speech_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def run(self, wavsignal, fs=16000):
:returns: A numpy array of size (NUMFRAMES by numcep * 3) containing features. Each row holds 1 feature vector.
"""
wavsignal = np.array(wavsignal, dtype=np.float)
wavsignal = np.array(wavsignal, dtype=np.float64)
# 获取输入特征
feat_mfcc = mfcc(wavsignal[0], samplerate=self.framesamplerate, winlen=self.winlen,
winstep=self.winstep, numcep=self.numcep, nfilt=self.nfilt, preemph=self.preemph)
Expand All @@ -99,7 +99,7 @@ def __init__(self, framesamplerate=16000, nfilt=26):
super().__init__(framesamplerate)

def run(self, wavsignal, fs=16000):
wavsignal = np.array(wavsignal, dtype=np.float)
wavsignal = np.array(wavsignal, dtype=np.float64)
# 获取输入特征
wav_feature = logfbank(wavsignal, fs, nfilt=self.nfilt)
return wav_feature
Expand Down Expand Up @@ -140,8 +140,8 @@ def run(self, wavsignal, fs=16000):
# wav_length = wav_arr.shape[1]

range0_end = int(len(wavsignal[0]) / fs * 1000 - time_window) // 10 + 1 # 计算循环终止的位置,也就是最终生成的窗数
data_input = np.zeros((range0_end, window_length // 2), dtype=np.float) # 用于存放最终的频率特征数据
data_line = np.zeros((1, window_length), dtype=np.float)
data_input = np.zeros((range0_end, window_length // 2), dtype=np.float64) # 用于存放最终的频率特征数据
data_line = np.zeros((1, window_length), dtype=np.float64)

for i in range(0, range0_end):
p_start = i * 160
Expand Down Expand Up @@ -192,8 +192,8 @@ def run(self, wavsignal, fs=16000):
# wav_length = wav_arr.shape[1]

range0_end = int(len(wavsignal[0]) / fs * 1000 - time_window) // 10 + 1 # 计算循环终止的位置,也就是最终生成的窗数
data_input = np.zeros((range0_end, window_length // 2), dtype=np.float) # 用于存放最终的频率特征数据
data_line = np.zeros((1, window_length), dtype=np.float)
data_input = np.zeros((range0_end, window_length // 2), dtype=np.float64) # 用于存放最终的频率特征数据
data_line = np.zeros((1, window_length), dtype=np.float64)

for i in range(0, range0_end):
p_start = i * 160
Expand Down
6 changes: 3 additions & 3 deletions speech_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def _data_generator(self, batch_size, data_loader):
数据生成器函数,用于Keras的generator_fit训练
batch_size: 一次产生的数据量
"""
labels = np.zeros((batch_size, 1), dtype=np.float)
labels = np.zeros((batch_size, 1), dtype=np.float64)
data_count = data_loader.get_data_count()
index = 0

while True:
X = np.zeros((batch_size,) + self.speech_model.input_shape, dtype=np.float)
X = np.zeros((batch_size,) + self.speech_model.input_shape, dtype=np.float64)
y = np.zeros((batch_size, self.max_label_length), dtype=np.int16)
input_length = []
label_length = []
Expand Down Expand Up @@ -233,7 +233,7 @@ def recognize_speech(self, wavsignal, fs):
"""
# 获取输入特征
data_input = self.speech_features.run(wavsignal, fs)
data_input = np.array(data_input, dtype=np.float)
data_input = np.array(data_input, dtype=np.float64)
# print(data_input,data_input.shape)
data_input = data_input.reshape(data_input.shape[0], data_input.shape[1], 1)
r1 = self.predict(data_input)
Expand Down

0 comments on commit 5352aff

Please sign in to comment.