forked from neonbjb/tortoise-tts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtts_api_worker.py
197 lines (174 loc) · 8.37 KB
/
tts_api_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import argparse
import os
import io
from collections import deque
import torch
import torchaudio
from tortoise.api import TextToSpeech
from tortoise.api_fast import TextToSpeech as TextToSpeechFast
from tortoise.utils.audio import load_voices
from tortoise.utils.text import split_and_recombine_text
from aime_api_worker_interface import APIWorkerInterface
# Directory passed to TextToSpeech as `models_dir` unless --model_dir overrides it.
MODELS_DIR = './models'
# Job type this worker announces to the AIME API server.
WORKER_JOB_TYPE = "tts_tortoise"
# Default value for --api_auth_key.
# NOTE(review): this is a credential committed to source; deployments should
# pass their own key via --api_auth_key.
DEFAULT_WORKER_AUTH_KEY = "5317e305b50505ca2b3284b4ae5f65a5"
# Worker version reported to the API server.
VERSION = 0

# Named quality presets. Each entry is a set of keyword overrides that the
# streaming path forwards to `tts_stream`: more autoregressive samples and
# more diffusion iterations trade speed for quality.
PRESETS = dict(
    ultra_fast=dict(num_autoregressive_samples=1, diffusion_iterations=10),
    fast=dict(num_autoregressive_samples=32, diffusion_iterations=50),
    standard=dict(num_autoregressive_samples=256, diffusion_iterations=200),
    high_quality=dict(num_autoregressive_samples=256, diffusion_iterations=400),
)
def _str2bool(value):
    """Convert a command-line string such as ``True``, ``false``, ``1`` or ``no`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string (including ``"False"``), so flags like
    ``--half False`` silently stayed enabled. This converter parses the text.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_flags(argv=None):
    """Parse the worker's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to None, which reads the real command line.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_deepspeed', action='store_true', help='Use deepspeed for speed bump.')
    parser.add_argument(
        '--kv_cache', type=_str2bool, help='If you disable this please wait for a long time to get the output', default=True
    )
    parser.add_argument(
        '--half', type=_str2bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True
    )
    parser.add_argument(
        '--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this '
        'should only be specified if you have custom checkpoints.', default=MODELS_DIR
    )
    parser.add_argument(
        '--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
    parser.add_argument(
        '--produce_debug_state', action='store_true', help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to false.', default=False
    )
    parser.add_argument(
        '--cvvp_amount', type=float, help='How much the CVVP model should influence the output. '
        'Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled)', default=.0
    )
    parser.add_argument(
        "--api_server", type=str, default="http://0.0.0.0:7777", help="Address of the AIME API server"
    )
    parser.add_argument(
        "--api_auth_key", type=str, default=DEFAULT_WORKER_AUTH_KEY, required=False, help="API server worker auth key"
    )
    parser.add_argument(
        "--gpu_id", type=int, default=0, required=False, help="ID of the GPU to be used"
    )
    parser.add_argument(
        "--stream", action='store_true', help="Use streaming"
    )
    return parser.parse_args(argv)
# Preset used when a job names no preset or one not present in PRESETS.
_FALLBACK_PRESET = 'fast'


def _resolve_preset(preset):
    """Return ``preset`` if it names an entry of PRESETS, else the fallback preset.

    Without this, a missing/unknown preset made ``**PRESETS.get(preset)``
    raise TypeError (not caught by the job loop) and stopped the worker.
    """
    if isinstance(preset, str) and preset in PRESETS:
        return preset
    print(f'Unknown preset {preset!r}; using {_FALLBACK_PRESET!r}')
    return _FALLBACK_PRESET


def _job_cvvp_amount(job_data, default):
    """CVVP weight for a job: the job's own value if truthy, else the CLI default."""
    return job_data.get('cvvp_amount') or default


def _encode_wav(audio, sample_rate=24000):
    """Serialize a waveform tensor to an in-memory WAV file.

    Args:
        audio: Tensor already shaped for ``torchaudio.save`` by the caller
            (presumably (channels, samples) — confirm).
        sample_rate: Output sample rate in Hz.

    Returns:
        An open ``io.BytesIO`` with the WAV bytes. It is deliberately not
        closed: the API worker interface reads it when it serializes the
        output, which may happen after this function returns.
    """
    buffer = io.BytesIO()
    torchaudio.save(buffer, audio.cpu(), format='wav', sample_rate=sample_rate)
    return buffer


def _wait_for_progress_ack(api_worker, poll_interval=0.01):
    """Block until ``api_worker.progress_data_received`` is truthy.

    The flag is presumably set by the interface once the previous progress
    update was delivered (TODO confirm). Sleeping between polls replaces a
    bare busy-wait that pinned a CPU core.
    """
    import time  # local import: the module header does not import time

    while not api_worker.progress_data_received:
        time.sleep(poll_interval)


def _is_job_canceled(api_worker):
    """True when the interface reports every held job as canceled.

    Older interface versions lack ``jobs_canceled``; they never report a cancel.
    """
    canceled = getattr(api_worker, 'jobs_canceled', None)
    return bool(canceled) and all(canceled)


def _pop_new_text_inputs(api_worker):
    """Drain text the client sent as progress input since the last poll.

    Returns:
        A list of non-empty ``text_input`` strings in arrival order, or None
        when the interface predates ``progress_input_params`` (legacy mode).
    """
    params = getattr(api_worker, 'progress_input_params', None)
    if params is None:
        return None
    new_texts = []
    while params:
        entry = params.pop(0)
        # Entries are indexed as entry[0]['text_input'] by this worker; skip
        # empty entries and missing text instead of crashing tts_stream later.
        text = entry[0].get('text_input') if entry else None
        if text:
            new_texts.append(text)
    return new_texts


def _run_streaming_job(tts, api_worker, job_data, voice_samples, conditioning_latents, preset, args, candidates):
    """Synthesize a job chunk by chunk, sending each audio chunk as a progress update.

    The job text is split with ``split_and_recombine_text``; text the client
    sends while the job runs is appended to the queue. When the queue is
    exhausted or the job is canceled, a final result holding only the model
    name is sent.
    """
    text_chunk_queue = deque(split_and_recombine_text(job_data.get('text') or ''))
    counter = 0
    legacy_warned = False

    def poll_client():
        """Return True if the job was canceled; otherwise queue any new client text."""
        nonlocal legacy_warned
        if _is_job_canceled(api_worker):
            print('Canceled')
            return True
        new_texts = _pop_new_text_inputs(api_worker)
        if new_texts is None:
            if not legacy_warned:
                print('Legacy mode! Update api worker interface to receive progress input parameters')
                legacy_warned = True
        elif new_texts:
            print('New input')
            text_chunk_queue.extend(new_texts)
        return False

    canceled = False
    while text_chunk_queue and not canceled:
        if poll_client():
            break
        stream = tts.tts_stream(
            text_chunk_queue.popleft(),
            voice_samples=voice_samples,
            conditioning_latents=conditioning_latents,
            k=candidates,
            verbose=True,
            use_deterministic_seed=args.seed,
            return_deterministic_state=False,
            overlap_wav_len=1024,
            stream_chunk_size=job_data.get('stream_chunk_size'),
            temperature=job_data.get('temperature'),
            length_penalty=job_data.get('length_penalty'),
            repetition_penalty=job_data.get('repetition_penalty'),
            top_p=job_data.get('top_p'),
            max_mel_tokens=job_data.get('max_mel_tokens'),
            cvvp_amount=_job_cvvp_amount(job_data, args.cvvp_amount),
            cond_free=True,
            cond_free_k=2,
            diffusion_temperature=1.0,
            **PRESETS[preset]
        )
        for audio_chunk, text_chunk in stream:
            if poll_client():
                canceled = True
                break
            counter += 1
            # A fresh dict per chunk so an update still in flight is never
            # mutated by the next one.
            progress = {
                'model_name': 'tortoise_tts',
                'audio_output': _encode_wav(audio_chunk.unsqueeze(0)),
                'text_output': text_chunk,
            }
            _wait_for_progress_ack(api_worker)
            api_worker.send_progress(counter, progress)
    _wait_for_progress_ack(api_worker)
    print('Done')
    api_worker.send_job_results({'model_name': 'tortoise_tts'})


def _run_batch_job(tts, api_worker, job_data, voice_samples, conditioning_latents, preset, args, candidates):
    """Synthesize the whole job text in one call and send it as the job result."""
    text = job_data.get('text')
    gen, _debug_state = tts.tts_with_preset(
        text,
        k=candidates,
        voice_samples=voice_samples,
        conditioning_latents=conditioning_latents,
        preset=preset,
        use_deterministic_seed=args.seed,
        return_deterministic_state=True,
        cvvp_amount=_job_cvvp_amount(job_data, args.cvvp_amount),
    )
    api_worker.send_job_results({
        'model_name': 'tortoise_tts',
        'audio_output': _encode_wav(gen.squeeze(0)),
        'text_output': text,
    })


def main():
    """Run the Tortoise TTS worker: load the model, then serve jobs forever.

    With ``--stream`` the fast model (``TextToSpeechFast``) is used and audio
    is returned chunk by chunk as progress updates. Otherwise ``TextToSpeech``
    synthesizes the whole clip and returns it as the job result. A
    ``ValueError`` raised while handling a job is printed and the worker
    continues with the next job; other exceptions propagate.
    """
    args = get_flags()
    if args.stream:
        print('Using Streaming Interface.')
        tts = TextToSpeechFast(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
    else:
        tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half, device_only=True)
    candidates = 1  # clips generated per request (passed to Tortoise as k)
    # NOTE(review): get_device_name() with no argument reports the current
    # CUDA device, which is not necessarily --gpu_id; confirm intended use.
    api_worker = APIWorkerInterface(
        args.api_server, WORKER_JOB_TYPE, args.api_auth_key, args.gpu_id,
        world_size=1, rank=0, gpu_name=torch.cuda.get_device_name(), worker_version=VERSION,
    )
    while True:
        try:
            job_data = api_worker.job_request()
            print(f'Processing job {job_data.get("job_id")}...', end='', flush=True)
            preset = _resolve_preset(job_data.get('preset'))
            print("Loading voice...")
            voice_samples, conditioning_latents = load_voices([job_data.get('voice')])
            if args.stream:
                _run_streaming_job(tts, api_worker, job_data, voice_samples, conditioning_latents, preset, args, candidates)
            else:
                _run_batch_job(tts, api_worker, job_data, voice_samples, conditioning_latents, preset, args, candidates)
        except ValueError as exc:
            print('Error', exc)
def split_text(text, max_length=200, nlp=None):
    """Group the sentences of ``text`` into chunks of at most ~``max_length`` characters.

    Sentences are never split; a single sentence longer than ``max_length``
    becomes its own chunk. Chunks are the sentences joined with one space.

    Args:
        text: Text to split.
        max_length: Soft character budget per chunk.
        nlp: A spaCy-style pipeline: a callable returning an object whose
            ``.sents`` yields spans with a ``.text`` attribute. Defaults to a
            module-level ``nlp`` if one is defined.

    Returns:
        List of chunk strings; empty when there are no sentences.

    Raises:
        NameError: if no ``nlp`` is passed and none is defined at module level
            (the file never defines one, so the original always raised here).
    """
    if nlp is None:
        nlp = globals().get('nlp')
        if nlp is None:
            raise NameError(
                "split_text() needs a spaCy-style `nlp` pipeline: pass it via the "
                "`nlp` argument or define a module-level `nlp`"
            )
    chunks = []
    chunk = []
    length = 0
    for sent in nlp(text).sents:
        sent_length = len(sent.text)
        # Flush only a non-empty chunk: an over-long first sentence used to
        # emit a spurious '' chunk.
        if chunk and length + sent_length > max_length:
            chunks.append(' '.join(chunk))
            chunk = []
            length = 0
        chunk.append(sent.text)
        length += sent_length + 1  # +1 for the joining space
    if chunk:
        chunks.append(' '.join(chunk))
    return chunks
# Script entry point: start the worker loop when run directly
# (e.g. `python tts_api_worker.py --stream`), but not when imported.
if __name__ == '__main__':
    main()